1use super::spirv::{
9 EXECUTION_MODEL_KERNEL, FUNCTION_CONTROL_NONE, OP_F_ADD, OP_F_DIV, OP_F_MUL, OP_F_SUB,
10 OP_I_ADD, OP_I_MUL, OP_U_DIV, OP_U_LESS_THAN, OP_U_MOD, OPENCL_EXP, OPENCL_FMAX,
11 STORAGE_CLASS_FUNCTION, SpvModule, WORKGROUP_SIZE, emit_preamble, load_gid_x,
12};
13
/// SPIR-V opcode of `OpISub` (integer subtraction).
const OP_I_SUB: u32 = 130;
/// SPIR-V opcode of `OpFOrdGreaterThan` (ordered floating-point `>`).
///
/// Note: 188 is `OpFOrdLessThanEqual`, which for non-NaN operands is the
/// inverse test — do not confuse the two.
const OP_F_ORD_GT: u32 = 186;
18
19#[allow(clippy::too_many_arguments)]
35pub fn conv2d_spirv(
36 n: u32,
37 c_in: u32,
38 h_in: u32,
39 w_in: u32,
40 k_out: u32,
41 fh: u32,
42 fw: u32,
43 oh: u32,
44 ow: u32,
45 stride_h: u32,
46 stride_w: u32,
47 pad_h: u32,
48 pad_w: u32,
49) -> Vec<u32> {
50 let mut m = SpvModule::new();
51 let b = emit_preamble(&mut m);
52
53 let main_fn = m.alloc_id();
54 let fn_ty = m.alloc_id();
55 let p_input = m.alloc_id();
56 let p_filter = m.alloc_id();
57 let p_output = m.alloc_id();
58
59 m.emit_type_function(
61 fn_ty,
62 b.ty_void,
63 &[
64 b.ty_ptr_cross_float,
65 b.ty_ptr_cross_float,
66 b.ty_ptr_cross_float,
67 ],
68 );
69
70 m.emit_entry_point(EXECUTION_MODEL_KERNEL, main_fn, "main", &[b.var_gid]);
71 m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);
72
73 let c_n = m.alloc_id();
75 m.emit_constant_u32(b.ty_uint, c_n, n);
76 let c_c_in = m.alloc_id();
77 m.emit_constant_u32(b.ty_uint, c_c_in, c_in);
78 let c_h_in = m.alloc_id();
79 m.emit_constant_u32(b.ty_uint, c_h_in, h_in);
80 let c_w_in = m.alloc_id();
81 m.emit_constant_u32(b.ty_uint, c_w_in, w_in);
82 let c_k_out = m.alloc_id();
83 m.emit_constant_u32(b.ty_uint, c_k_out, k_out);
84 let c_fh = m.alloc_id();
85 m.emit_constant_u32(b.ty_uint, c_fh, fh);
86 let c_fw = m.alloc_id();
87 m.emit_constant_u32(b.ty_uint, c_fw, fw);
88 let c_oh = m.alloc_id();
89 m.emit_constant_u32(b.ty_uint, c_oh, oh);
90 let c_ow = m.alloc_id();
91 m.emit_constant_u32(b.ty_uint, c_ow, ow);
92 let c_stride_h = m.alloc_id();
93 m.emit_constant_u32(b.ty_uint, c_stride_h, stride_h);
94 let c_stride_w = m.alloc_id();
95 m.emit_constant_u32(b.ty_uint, c_stride_w, stride_w);
96 let c_pad_h = m.alloc_id();
97 m.emit_constant_u32(b.ty_uint, c_pad_h, pad_h);
98 let c_pad_w = m.alloc_id();
99 m.emit_constant_u32(b.ty_uint, c_pad_w, pad_w);
100
101 let label_entry = m.alloc_id();
103 let label_body = m.alloc_id();
104 let label_merge = m.alloc_id();
105
106 m.emit_function(b.ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
108 m.emit_function_parameter(b.ty_ptr_cross_float, p_input);
109 m.emit_function_parameter(b.ty_ptr_cross_float, p_filter);
110 m.emit_function_parameter(b.ty_ptr_cross_float, p_output);
111 m.emit_label(label_entry);
112
113 let gid = load_gid_x(&mut m, &b);
114
115 let t1 = m.alloc_id();
117 m.emit(OP_I_MUL, &[b.ty_uint, t1, c_n, c_k_out]);
118 let t2 = m.alloc_id();
119 m.emit(OP_I_MUL, &[b.ty_uint, t2, t1, c_oh]);
120 let total = m.alloc_id();
121 m.emit(OP_I_MUL, &[b.ty_uint, total, t2, c_ow]);
122
123 let cond = m.alloc_id();
124 m.emit(OP_U_LESS_THAN, &[b.ty_bool, cond, gid, total]);
125 m.emit_selection_merge(label_merge);
126 m.emit_branch_conditional(cond, label_body, label_merge);
127
128 m.emit_label(label_body);
129
130 let ox = m.alloc_id();
132 m.emit(OP_U_MOD, &[b.ty_uint, ox, gid, c_ow]);
133 let tmp1 = m.alloc_id();
134 m.emit(OP_U_DIV, &[b.ty_uint, tmp1, gid, c_ow]);
135 let oy = m.alloc_id();
136 m.emit(OP_U_MOD, &[b.ty_uint, oy, tmp1, c_oh]);
137 let tmp2 = m.alloc_id();
138 m.emit(OP_U_DIV, &[b.ty_uint, tmp2, tmp1, c_oh]);
139 let kf = m.alloc_id();
140 m.emit(OP_U_MOD, &[b.ty_uint, kf, tmp2, c_k_out]);
141 let b_idx = m.alloc_id();
142 m.emit(OP_U_DIV, &[b.ty_uint, b_idx, tmp2, c_k_out]);
143
144 let var_acc = m.alloc_id();
146 m.emit_variable(b.ty_ptr_func_float, var_acc, STORAGE_CLASS_FUNCTION);
147 m.emit_store(var_acc, b.c_float_0);
148
149 let flat_total_id = m.alloc_id();
151 let flat_t1 = m.alloc_id();
152 m.emit(OP_I_MUL, &[b.ty_uint, flat_t1, c_c_in, c_fh]);
153 m.emit(OP_I_MUL, &[b.ty_uint, flat_total_id, flat_t1, c_fw]);
154
155 let var_flat = m.alloc_id();
156 m.emit_variable(b.ty_ptr_func_uint, var_flat, STORAGE_CLASS_FUNCTION);
157 m.emit_store(var_flat, b.c_uint_0);
158
159 let lbl_loop_hdr = m.alloc_id();
160 let lbl_loop_body = m.alloc_id();
161 let lbl_loop_cont = m.alloc_id();
162 let lbl_loop_merge = m.alloc_id();
163
164 m.emit_branch(lbl_loop_hdr);
165
166 m.emit_label(lbl_loop_hdr);
168 let flat_val = m.alloc_id();
169 m.emit_load(b.ty_uint, flat_val, var_flat);
170 let loop_cond = m.alloc_id();
171 m.emit(
172 OP_U_LESS_THAN,
173 &[b.ty_bool, loop_cond, flat_val, flat_total_id],
174 );
175 m.emit_loop_merge(lbl_loop_merge, lbl_loop_cont);
176 m.emit_branch_conditional(loop_cond, lbl_loop_body, lbl_loop_merge);
177
178 m.emit_label(lbl_loop_body);
180
181 let fx = m.alloc_id();
183 m.emit(OP_U_MOD, &[b.ty_uint, fx, flat_val, c_fw]);
184 let ftmp1 = m.alloc_id();
185 m.emit(OP_U_DIV, &[b.ty_uint, ftmp1, flat_val, c_fw]);
186 let fy = m.alloc_id();
187 m.emit(OP_U_MOD, &[b.ty_uint, fy, ftmp1, c_fh]);
188 let ci = m.alloc_id();
189 m.emit(OP_U_DIV, &[b.ty_uint, ci, ftmp1, c_fh]);
190
191 let oy_sh = m.alloc_id();
193 m.emit(OP_I_MUL, &[b.ty_uint, oy_sh, oy, c_stride_h]);
194 let iy_raw = m.alloc_id();
195 m.emit(OP_I_ADD, &[b.ty_uint, iy_raw, oy_sh, fy]);
196 let ox_sw = m.alloc_id();
197 m.emit(OP_I_MUL, &[b.ty_uint, ox_sw, ox, c_stride_w]);
198 let ix_raw = m.alloc_id();
199 m.emit(OP_I_ADD, &[b.ty_uint, ix_raw, ox_sw, fx]);
200
201 let lbl_skip = m.alloc_id();
204
205 let iy_lt_pad = m.alloc_id();
206 m.emit(OP_U_LESS_THAN, &[b.ty_bool, iy_lt_pad, iy_raw, c_pad_h]);
207 let lbl_iy_ok = m.alloc_id();
208 m.emit_selection_merge(lbl_skip);
209 m.emit_branch_conditional(iy_lt_pad, lbl_skip, lbl_iy_ok);
210
211 m.emit_label(lbl_iy_ok);
212 let iy_real = m.alloc_id();
213 m.emit(OP_I_SUB, &[b.ty_uint, iy_real, iy_raw, c_pad_h]);
214 let iy_in_bounds = m.alloc_id();
215 m.emit(OP_U_LESS_THAN, &[b.ty_bool, iy_in_bounds, iy_real, c_h_in]);
216 let lbl_ix_check = m.alloc_id();
217 m.emit_selection_merge(lbl_skip);
218 m.emit_branch_conditional(iy_in_bounds, lbl_ix_check, lbl_skip);
219
220 m.emit_label(lbl_ix_check);
221 let ix_lt_pad = m.alloc_id();
222 m.emit(OP_U_LESS_THAN, &[b.ty_bool, ix_lt_pad, ix_raw, c_pad_w]);
223 let lbl_ix_ok = m.alloc_id();
224 m.emit_selection_merge(lbl_skip);
225 m.emit_branch_conditional(ix_lt_pad, lbl_skip, lbl_ix_ok);
226
227 m.emit_label(lbl_ix_ok);
228 let ix_real = m.alloc_id();
229 m.emit(OP_I_SUB, &[b.ty_uint, ix_real, ix_raw, c_pad_w]);
230 let ix_in_bounds = m.alloc_id();
231 m.emit(OP_U_LESS_THAN, &[b.ty_bool, ix_in_bounds, ix_real, c_w_in]);
232 let lbl_accum = m.alloc_id();
233 m.emit_selection_merge(lbl_skip);
234 m.emit_branch_conditional(ix_in_bounds, lbl_accum, lbl_skip);
235
236 m.emit_label(lbl_accum);
237
238 let in1 = m.alloc_id();
240 m.emit(OP_I_MUL, &[b.ty_uint, in1, b_idx, c_c_in]);
241 let in2 = m.alloc_id();
242 m.emit(OP_I_ADD, &[b.ty_uint, in2, in1, ci]);
243 let in3 = m.alloc_id();
244 m.emit(OP_I_MUL, &[b.ty_uint, in3, in2, c_h_in]);
245 let in4 = m.alloc_id();
246 m.emit(OP_I_ADD, &[b.ty_uint, in4, in3, iy_real]);
247 let in5 = m.alloc_id();
248 m.emit(OP_I_MUL, &[b.ty_uint, in5, in4, c_w_in]);
249 let in_idx = m.alloc_id();
250 m.emit(OP_I_ADD, &[b.ty_uint, in_idx, in5, ix_real]);
251
252 let inp_ptr = m.alloc_id();
253 m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, inp_ptr, p_input, in_idx);
254 let inp_val = m.alloc_id();
255 m.emit_load(b.ty_float, inp_val, inp_ptr);
256
257 let f1 = m.alloc_id();
259 m.emit(OP_I_MUL, &[b.ty_uint, f1, kf, c_c_in]);
260 let f2 = m.alloc_id();
261 m.emit(OP_I_ADD, &[b.ty_uint, f2, f1, ci]);
262 let f3 = m.alloc_id();
263 m.emit(OP_I_MUL, &[b.ty_uint, f3, f2, c_fh]);
264 let f4 = m.alloc_id();
265 m.emit(OP_I_ADD, &[b.ty_uint, f4, f3, fy]);
266 let f5 = m.alloc_id();
267 m.emit(OP_I_MUL, &[b.ty_uint, f5, f4, c_fw]);
268 let flt_idx = m.alloc_id();
269 m.emit(OP_I_ADD, &[b.ty_uint, flt_idx, f5, fx]);
270
271 let flt_ptr = m.alloc_id();
272 m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, flt_ptr, p_filter, flt_idx);
273 let flt_val = m.alloc_id();
274 m.emit_load(b.ty_float, flt_val, flt_ptr);
275
276 let prod = m.alloc_id();
278 m.emit(OP_F_MUL, &[b.ty_float, prod, inp_val, flt_val]);
279 let old_acc = m.alloc_id();
280 m.emit_load(b.ty_float, old_acc, var_acc);
281 let new_acc = m.alloc_id();
282 m.emit(OP_F_ADD, &[b.ty_float, new_acc, old_acc, prod]);
283 m.emit_store(var_acc, new_acc);
284
285 m.emit_branch(lbl_skip);
286
287 m.emit_label(lbl_skip);
288 m.emit_branch(lbl_loop_cont);
289
290 m.emit_label(lbl_loop_cont);
292 let flat_inc = m.alloc_id();
293 m.emit(OP_I_ADD, &[b.ty_uint, flat_inc, flat_val, b.c_uint_1]);
294 m.emit_store(var_flat, flat_inc);
295 m.emit_branch(lbl_loop_hdr);
296
297 m.emit_label(lbl_loop_merge);
299
300 let final_acc = m.alloc_id();
301 m.emit_load(b.ty_float, final_acc, var_acc);
302
303 let out_ptr = m.alloc_id();
304 m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, out_ptr, p_output, gid);
305 m.emit_store(out_ptr, final_acc);
306
307 m.emit_branch(label_merge);
308
309 m.emit_label(label_merge);
310 m.emit_return();
311 m.emit_function_end();
312
313 m.finalize()
314}
315
316#[allow(clippy::too_many_arguments)]
333pub fn attention_spirv(
334 batch_heads: u32,
335 seq_q: u32,
336 seq_kv: u32,
337 head_dim: u32,
338 scale: f32,
339 causal: bool,
340) -> Vec<u32> {
341 let mut m = SpvModule::new();
342 let b = emit_preamble(&mut m);
343
344 let main_fn = m.alloc_id();
345 let fn_ty = m.alloc_id();
346 let p_q = m.alloc_id();
347 let p_k = m.alloc_id();
348 let p_v = m.alloc_id();
349 let p_o = m.alloc_id();
350
351 m.emit_type_function(
353 fn_ty,
354 b.ty_void,
355 &[
356 b.ty_ptr_cross_float,
357 b.ty_ptr_cross_float,
358 b.ty_ptr_cross_float,
359 b.ty_ptr_cross_float,
360 ],
361 );
362
363 m.emit_entry_point(EXECUTION_MODEL_KERNEL, main_fn, "main", &[b.var_gid]);
364 m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);
365
366 let c_batch_heads = m.alloc_id();
368 m.emit_constant_u32(b.ty_uint, c_batch_heads, batch_heads);
369 let c_seq_q = m.alloc_id();
370 m.emit_constant_u32(b.ty_uint, c_seq_q, seq_q);
371 let c_seq_kv = m.alloc_id();
372 m.emit_constant_u32(b.ty_uint, c_seq_kv, seq_kv);
373 let c_head_dim = m.alloc_id();
374 m.emit_constant_u32(b.ty_uint, c_head_dim, head_dim);
375 let c_scale = m.alloc_id();
376 m.emit_constant_f32(b.ty_float, c_scale, scale);
377 let c_neg_inf = m.alloc_id();
378 m.emit_constant_f32(b.ty_float, c_neg_inf, f32::NEG_INFINITY);
379 let c_skv_hd = m.alloc_id();
381 m.emit_constant_u32(b.ty_uint, c_skv_hd, seq_kv * head_dim);
382
383 let label_entry = m.alloc_id();
385 let label_body = m.alloc_id();
386 let label_merge = m.alloc_id();
387
388 m.emit_function(b.ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
389 m.emit_function_parameter(b.ty_ptr_cross_float, p_q);
390 m.emit_function_parameter(b.ty_ptr_cross_float, p_k);
391 m.emit_function_parameter(b.ty_ptr_cross_float, p_v);
392 m.emit_function_parameter(b.ty_ptr_cross_float, p_o);
393 m.emit_label(label_entry);
394
395 let gid = load_gid_x(&mut m, &b);
396
397 let total = m.alloc_id();
399 m.emit(OP_I_MUL, &[b.ty_uint, total, c_batch_heads, c_seq_q]);
400
401 let cond = m.alloc_id();
402 m.emit(OP_U_LESS_THAN, &[b.ty_bool, cond, gid, total]);
403 m.emit_selection_merge(label_merge);
404 m.emit_branch_conditional(cond, label_body, label_merge);
405
406 m.emit_label(label_body);
407
408 let bh = m.alloc_id();
410 m.emit(OP_U_DIV, &[b.ty_uint, bh, gid, c_seq_q]);
411 let sq = m.alloc_id();
412 m.emit(OP_U_MOD, &[b.ty_uint, sq, gid, c_seq_q]);
413
414 let q_base = m.alloc_id();
416 m.emit(OP_I_MUL, &[b.ty_uint, q_base, gid, c_head_dim]);
417 let kv_base = m.alloc_id();
419 m.emit(OP_I_MUL, &[b.ty_uint, kv_base, bh, c_skv_hd]);
420
421 let var_max = m.alloc_id();
423 m.emit_variable(b.ty_ptr_func_float, var_max, STORAGE_CLASS_FUNCTION);
424 m.emit_store(var_max, c_neg_inf);
425
426 emit_score_pass(
428 &mut m, &b, causal, sq, c_seq_kv, c_head_dim, c_scale, q_base, kv_base, p_q, p_k, var_max,
429 true, None, None, p_v, p_o,
430 );
431
432 let final_max = m.alloc_id();
433 m.emit_load(b.ty_float, final_max, var_max);
434
435 let var_sum_exp = m.alloc_id();
437 m.emit_variable(b.ty_ptr_func_float, var_sum_exp, STORAGE_CLASS_FUNCTION);
438 m.emit_store(var_sum_exp, b.c_float_0);
439
440 emit_score_pass(
441 &mut m,
442 &b,
443 causal,
444 sq,
445 c_seq_kv,
446 c_head_dim,
447 c_scale,
448 q_base,
449 kv_base,
450 p_q,
451 p_k,
452 var_sum_exp,
453 false,
454 Some(final_max),
455 Some(p_o),
456 p_v,
457 p_o,
458 );
459
460 let sum_final = m.alloc_id();
462 m.emit_load(b.ty_float, sum_final, var_sum_exp);
463
464 let sum_gt_zero = m.alloc_id();
465 m.emit(
466 OP_F_ORD_GT,
467 &[b.ty_bool, sum_gt_zero, sum_final, b.c_float_0],
468 );
469
470 let lbl_norm = m.alloc_id();
471 let lbl_norm_merge = m.alloc_id();
472 m.emit_selection_merge(lbl_norm_merge);
473 m.emit_branch_conditional(sum_gt_zero, lbl_norm, lbl_norm_merge);
474
475 m.emit_label(lbl_norm);
476
477 let var_d4 = m.alloc_id();
479 m.emit_variable(b.ty_ptr_func_uint, var_d4, STORAGE_CLASS_FUNCTION);
480 m.emit_store(var_d4, b.c_uint_0);
481
482 let lbl_d4_hdr = m.alloc_id();
483 let lbl_d4_body = m.alloc_id();
484 let lbl_d4_cont = m.alloc_id();
485 let lbl_d4_merge = m.alloc_id();
486
487 m.emit_branch(lbl_d4_hdr);
488
489 m.emit_label(lbl_d4_hdr);
490 let d4_val = m.alloc_id();
491 m.emit_load(b.ty_uint, d4_val, var_d4);
492 let d4_cond = m.alloc_id();
493 m.emit(OP_U_LESS_THAN, &[b.ty_bool, d4_cond, d4_val, c_head_dim]);
494 m.emit_loop_merge(lbl_d4_merge, lbl_d4_cont);
495 m.emit_branch_conditional(d4_cond, lbl_d4_body, lbl_d4_merge);
496
497 m.emit_label(lbl_d4_body);
498 let o4_idx = m.alloc_id();
499 m.emit(OP_I_ADD, &[b.ty_uint, o4_idx, q_base, d4_val]);
500 let o4_ptr = m.alloc_id();
501 m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, o4_ptr, p_o, o4_idx);
502 let o4_val = m.alloc_id();
503 m.emit_load(b.ty_float, o4_val, o4_ptr);
504 let o4_normed = m.alloc_id();
505 m.emit(OP_F_DIV, &[b.ty_float, o4_normed, o4_val, sum_final]);
506 m.emit_store(o4_ptr, o4_normed);
507
508 m.emit_branch(lbl_d4_cont);
509 m.emit_label(lbl_d4_cont);
510 let d4_inc = m.alloc_id();
511 m.emit(OP_I_ADD, &[b.ty_uint, d4_inc, d4_val, b.c_uint_1]);
512 m.emit_store(var_d4, d4_inc);
513 m.emit_branch(lbl_d4_hdr);
514
515 m.emit_label(lbl_d4_merge);
516
517 m.emit_branch(lbl_norm_merge);
518 m.emit_label(lbl_norm_merge);
519
520 m.emit_branch(label_merge);
521
522 m.emit_label(label_merge);
523 m.emit_return();
524 m.emit_function_end();
525
526 m.finalize()
527}
528
529#[allow(clippy::too_many_arguments)]
535fn emit_score_pass(
536 m: &mut SpvModule,
537 b: &super::spirv::BaseIds,
538 causal: bool,
539 sq: u32,
540 c_seq_kv: u32,
541 c_head_dim: u32,
542 c_scale: u32,
543 q_base: u32,
544 kv_base: u32,
545 p_q: u32,
546 p_k: u32,
547 accum_var: u32,
548 is_max_pass: bool,
549 max_val: Option<u32>,
550 o_buf: Option<u32>,
551 p_v: u32,
552 _p_o_unused: u32,
553) {
554 let var_sk = m.alloc_id();
555 m.emit_variable(b.ty_ptr_func_uint, var_sk, STORAGE_CLASS_FUNCTION);
556 m.emit_store(var_sk, b.c_uint_0);
557
558 let lbl_hdr = m.alloc_id();
559 let lbl_body = m.alloc_id();
560 let lbl_cont = m.alloc_id();
561 let lbl_merge = m.alloc_id();
562
563 m.emit_branch(lbl_hdr);
564
565 m.emit_label(lbl_hdr);
566 let sk_val = m.alloc_id();
567 m.emit_load(b.ty_uint, sk_val, var_sk);
568 let cond = m.alloc_id();
569 m.emit(OP_U_LESS_THAN, &[b.ty_bool, cond, sk_val, c_seq_kv]);
570 m.emit_loop_merge(lbl_merge, lbl_cont);
571 m.emit_branch_conditional(cond, lbl_body, lbl_merge);
572
573 m.emit_label(lbl_body);
574
575 let lbl_compute = m.alloc_id();
576 let lbl_skip = m.alloc_id();
577 if causal {
578 let sk_gt_sq = m.alloc_id();
579 m.emit(OP_U_LESS_THAN, &[b.ty_bool, sk_gt_sq, sq, sk_val]);
580 m.emit_selection_merge(lbl_skip);
581 m.emit_branch_conditional(sk_gt_sq, lbl_skip, lbl_compute);
582 } else {
583 m.emit_branch(lbl_compute);
584 }
585
586 m.emit_label(lbl_compute);
587
588 let sk_hd = m.alloc_id();
590 m.emit(OP_I_MUL, &[b.ty_uint, sk_hd, sk_val, c_head_dim]);
591 let k_off = m.alloc_id();
592 m.emit(OP_I_ADD, &[b.ty_uint, k_off, kv_base, sk_hd]);
593
594 let var_d = m.alloc_id();
596 m.emit_variable(b.ty_ptr_func_uint, var_d, STORAGE_CLASS_FUNCTION);
597 m.emit_store(var_d, b.c_uint_0);
598 let var_dot = m.alloc_id();
599 m.emit_variable(b.ty_ptr_func_float, var_dot, STORAGE_CLASS_FUNCTION);
600 m.emit_store(var_dot, b.c_float_0);
601
602 let lbl_d_hdr = m.alloc_id();
603 let lbl_d_body = m.alloc_id();
604 let lbl_d_cont = m.alloc_id();
605 let lbl_d_merge = m.alloc_id();
606
607 m.emit_branch(lbl_d_hdr);
608
609 m.emit_label(lbl_d_hdr);
610 let d_val = m.alloc_id();
611 m.emit_load(b.ty_uint, d_val, var_d);
612 let d_cond = m.alloc_id();
613 m.emit(OP_U_LESS_THAN, &[b.ty_bool, d_cond, d_val, c_head_dim]);
614 m.emit_loop_merge(lbl_d_merge, lbl_d_cont);
615 m.emit_branch_conditional(d_cond, lbl_d_body, lbl_d_merge);
616
617 m.emit_label(lbl_d_body);
618 let q_idx = m.alloc_id();
619 m.emit(OP_I_ADD, &[b.ty_uint, q_idx, q_base, d_val]);
620 let q_ptr = m.alloc_id();
621 m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, q_ptr, p_q, q_idx);
622 let q_val = m.alloc_id();
623 m.emit_load(b.ty_float, q_val, q_ptr);
624 let k_idx = m.alloc_id();
625 m.emit(OP_I_ADD, &[b.ty_uint, k_idx, k_off, d_val]);
626 let k_ptr = m.alloc_id();
627 m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, k_ptr, p_k, k_idx);
628 let k_val = m.alloc_id();
629 m.emit_load(b.ty_float, k_val, k_ptr);
630
631 let qk_prod = m.alloc_id();
632 m.emit(OP_F_MUL, &[b.ty_float, qk_prod, q_val, k_val]);
633 let old_dot = m.alloc_id();
634 m.emit_load(b.ty_float, old_dot, var_dot);
635 let new_dot = m.alloc_id();
636 m.emit(OP_F_ADD, &[b.ty_float, new_dot, old_dot, qk_prod]);
637 m.emit_store(var_dot, new_dot);
638
639 m.emit_branch(lbl_d_cont);
640 m.emit_label(lbl_d_cont);
641 let d_inc = m.alloc_id();
642 m.emit(OP_I_ADD, &[b.ty_uint, d_inc, d_val, b.c_uint_1]);
643 m.emit_store(var_d, d_inc);
644 m.emit_branch(lbl_d_hdr);
645
646 m.emit_label(lbl_d_merge);
647
648 let dot_final = m.alloc_id();
650 m.emit_load(b.ty_float, dot_final, var_dot);
651 let score = m.alloc_id();
652 m.emit(OP_F_MUL, &[b.ty_float, score, dot_final, c_scale]);
653
654 if is_max_pass {
655 let old_acc = m.alloc_id();
657 m.emit_load(b.ty_float, old_acc, accum_var);
658 let new_acc = m.alloc_id();
659 m.emit_opencl_ext(
660 b.opencl_ext,
661 b.ty_float,
662 new_acc,
663 OPENCL_FMAX,
664 &[old_acc, score],
665 );
666 m.emit_store(accum_var, new_acc);
667 } else {
668 let max_id = max_val.unwrap_or(b.c_float_0);
670 let score_shifted = m.alloc_id();
671 m.emit(OP_F_SUB, &[b.ty_float, score_shifted, score, max_id]);
672 let w = m.alloc_id();
673 m.emit_opencl_ext(b.opencl_ext, b.ty_float, w, OPENCL_EXP, &[score_shifted]);
674
675 let old_sum = m.alloc_id();
677 m.emit_load(b.ty_float, old_sum, accum_var);
678 let new_sum = m.alloc_id();
679 m.emit(OP_F_ADD, &[b.ty_float, new_sum, old_sum, w]);
680 m.emit_store(accum_var, new_sum);
681
682 if let Some(o_buf_id) = o_buf {
684 let v_off = m.alloc_id();
685 m.emit(OP_I_ADD, &[b.ty_uint, v_off, kv_base, sk_hd]);
686
687 let var_d3 = m.alloc_id();
688 m.emit_variable(b.ty_ptr_func_uint, var_d3, STORAGE_CLASS_FUNCTION);
689 m.emit_store(var_d3, b.c_uint_0);
690
691 let lbl_d3_hdr = m.alloc_id();
692 let lbl_d3_body = m.alloc_id();
693 let lbl_d3_cont = m.alloc_id();
694 let lbl_d3_merge = m.alloc_id();
695
696 m.emit_branch(lbl_d3_hdr);
697
698 m.emit_label(lbl_d3_hdr);
699 let d3_val = m.alloc_id();
700 m.emit_load(b.ty_uint, d3_val, var_d3);
701 let d3_cond = m.alloc_id();
702 m.emit(OP_U_LESS_THAN, &[b.ty_bool, d3_cond, d3_val, c_head_dim]);
703 m.emit_loop_merge(lbl_d3_merge, lbl_d3_cont);
704 m.emit_branch_conditional(d3_cond, lbl_d3_body, lbl_d3_merge);
705
706 m.emit_label(lbl_d3_body);
707 let v_idx = m.alloc_id();
708 m.emit(OP_I_ADD, &[b.ty_uint, v_idx, v_off, d3_val]);
709 let v_ptr = m.alloc_id();
710 m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, v_ptr, p_v, v_idx);
711 let v_val = m.alloc_id();
712 m.emit_load(b.ty_float, v_val, v_ptr);
713 let wv = m.alloc_id();
714 m.emit(OP_F_MUL, &[b.ty_float, wv, w, v_val]);
715
716 let o_idx = m.alloc_id();
717 m.emit(OP_I_ADD, &[b.ty_uint, o_idx, q_base, d3_val]);
718 let o_ptr = m.alloc_id();
719 m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, o_ptr, o_buf_id, o_idx);
720 let o_old = m.alloc_id();
721 m.emit_load(b.ty_float, o_old, o_ptr);
722 let o_new = m.alloc_id();
723 m.emit(OP_F_ADD, &[b.ty_float, o_new, o_old, wv]);
724 m.emit_store(o_ptr, o_new);
725
726 m.emit_branch(lbl_d3_cont);
727 m.emit_label(lbl_d3_cont);
728 let d3_inc = m.alloc_id();
729 m.emit(OP_I_ADD, &[b.ty_uint, d3_inc, d3_val, b.c_uint_1]);
730 m.emit_store(var_d3, d3_inc);
731 m.emit_branch(lbl_d3_hdr);
732
733 m.emit_label(lbl_d3_merge);
734 }
735 }
736
737 m.emit_branch(lbl_skip);
738 m.emit_label(lbl_skip);
739 m.emit_branch(lbl_cont);
740
741 m.emit_label(lbl_cont);
742 let sk_inc = m.alloc_id();
743 m.emit(OP_I_ADD, &[b.ty_uint, sk_inc, sk_val, b.c_uint_1]);
744 m.emit_store(var_sk, sk_inc);
745 m.emit_branch(lbl_hdr);
746
747 m.emit_label(lbl_merge);
748}
749
#[cfg(test)]
mod tests {
    use super::*;
    use crate::spirv::SPIRV_MAGIC;

    /// The SPIR-V magic number, spelled out independently of `SPIRV_MAGIC`.
    const EXPECTED_MAGIC: u32 = 0x0723_0203;

    /// Sanity-checks the fixed five-word SPIR-V header: magic number,
    /// non-zero id bound and zero schema word.
    fn check_valid_spirv(words: &[u32]) {
        assert!(words.len() >= 5, "too short for SPIR-V header");
        let (magic, id_bound, schema) = (words[0], words[3], words[4]);
        assert_eq!(magic, SPIRV_MAGIC, "bad magic");
        assert!(id_bound > 0, "ID bound must be > 0");
        assert_eq!(schema, 0, "schema must be 0");
    }

    #[test]
    fn conv2d_spirv_valid() {
        // 3x8x8 input, sixteen 3x3 filters, stride 1, no padding -> 6x6 output.
        check_valid_spirv(&conv2d_spirv(1, 3, 8, 8, 16, 3, 3, 6, 6, 1, 1, 0, 0));
    }

    #[test]
    fn conv2d_spirv_with_padding() {
        // 5x5 input, 3x3 filter, padding 1 keeps the 5x5 spatial size.
        check_valid_spirv(&conv2d_spirv(2, 1, 5, 5, 4, 3, 3, 5, 5, 1, 1, 1, 1));
    }

    #[test]
    fn conv2d_spirv_1x1() {
        check_valid_spirv(&conv2d_spirv(1, 3, 4, 4, 8, 1, 1, 4, 4, 1, 1, 0, 0));
    }

    #[test]
    fn attention_spirv_valid() {
        check_valid_spirv(&attention_spirv(2, 4, 4, 8, 0.125, false));
    }

    #[test]
    fn attention_spirv_causal() {
        check_valid_spirv(&attention_spirv(1, 8, 8, 16, 0.25, true));
    }

    #[test]
    fn attention_spirv_magic_number() {
        let module = attention_spirv(1, 4, 4, 8, 0.125, false);
        assert_eq!(module[0], EXPECTED_MAGIC);
    }

    #[test]
    fn conv2d_spirv_magic_number() {
        let module = conv2d_spirv(1, 1, 4, 4, 1, 1, 1, 4, 4, 1, 1, 0, 0);
        assert_eq!(module[0], EXPECTED_MAGIC);
    }
}