1use crate::spirv::{
16 EXECUTION_MODEL_KERNEL, FUNCTION_CONTROL_NONE, OP_COMPOSITE_EXTRACT, OP_CONTROL_BARRIER,
17 OP_F_ADD, OP_F_MUL, OP_GROUP_NON_UNIFORM_FADD, OP_GROUP_NON_UNIFORM_SHUFFLE, OP_I_ADD,
18 OP_I_MUL, OP_PHI, OP_TYPE_ARRAY, OP_U_LESS_THAN, STORAGE_CLASS_FUNCTION, SpvModule,
19 WORKGROUP_SIZE,
20};
21
// Numeric SPIR-V enumerants used by the subgroup kernels below. The values are
// the literal words written into the instruction stream.

// `Capability` operands for OpCapability.
const CAPABILITY_ADDRESSES: u32 = 4;
const CAPABILITY_KERNEL: u32 = 6;
const CAPABILITY_GROUP_NON_UNIFORM: u32 = 61;
const CAPABILITY_GROUP_NON_UNIFORM_ARITHMETIC: u32 = 63;
const CAPABILITY_GROUP_NON_UNIFORM_SHUFFLE: u32 = 65;

// Operands for OpMemoryModel.
const ADDRESSING_MODEL_PHYSICAL64: u32 = 2;
const MEMORY_MODEL_OPENCL: u32 = 2;

// `Decoration` value that marks a variable as a built-in.
const DECORATION_BUILTIN: u32 = 11;

// `BuiltIn` selectors attached through `DECORATION_BUILTIN`.
const BUILTIN_GLOBAL_INVOCATION_ID: u32 = 28;
const BUILTIN_NUM_SUBGROUPS: u32 = 38;
const BUILTIN_SUBGROUP_ID: u32 = 40;
const BUILTIN_SUBGROUP_LOCAL_INVOCATION_ID: u32 = 41;

// `StorageClass` values for pointer types and variables.
const STORAGE_CLASS_INPUT: u32 = 1;
const STORAGE_CLASS_WORKGROUP: u32 = 4;
const STORAGE_CLASS_CROSS_WORKGROUP: u32 = 5;

// `Scope` values; the kernels materialise them as uint constants.
const SCOPE_WORKGROUP: u32 = 2;
const SCOPE_SUBGROUP: u32 = 3;

// `MemorySemantics` bit selecting Workgroup storage-class memory.
const MEMORY_SEMANTICS_WORKGROUP_MEMORY: u32 = 0x100;

// `GroupOperation` literals for the OpGroupNonUniform* arithmetic instructions.
const GROUP_OPERATION_REDUCE: u32 = 0;
const GROUP_OPERATION_INCLUSIVE_SCAN: u32 = 1;

// Opcodes declared locally rather than imported from `crate::spirv`.
const OP_SELECT: u32 = 169;
const OP_I_EQUAL: u32 = 170;

// Element count of the Workgroup scratch array that holds one partial result
// per subgroup.
const MAX_SUBGROUPS: u32 = 32;
66
67pub fn reduction_subgroup_spirv() -> Vec<u32> {
80 let mut m = SpvModule::new();
81
82 m.emit_capability(CAPABILITY_KERNEL);
84 m.emit_capability(CAPABILITY_ADDRESSES);
85 m.emit_capability(CAPABILITY_GROUP_NON_UNIFORM);
86 m.emit_capability(CAPABILITY_GROUP_NON_UNIFORM_ARITHMETIC);
87
88 let opencl_ext = m.alloc_id();
90 m.emit_ext_inst_import(opencl_ext, "OpenCL.std");
91
92 m.emit_memory_model(ADDRESSING_MODEL_PHYSICAL64, MEMORY_MODEL_OPENCL);
94
95 let ty_void = m.alloc_id();
97 let ty_bool = m.alloc_id();
98 let ty_uint = m.alloc_id();
99 let ty_float = m.alloc_id();
100 let ty_v3uint = m.alloc_id();
101 let ty_ptr_input_v3uint = m.alloc_id();
102 let ty_ptr_cross_float = m.alloc_id();
103 let ty_ptr_func_float = m.alloc_id();
104 let ty_ptr_wg_float = m.alloc_id();
105 let ty_ptr_input_uint = m.alloc_id();
106 let ty_arr_float = m.alloc_id();
107 let ty_ptr_wg_arr = m.alloc_id();
108
109 let c_uint_0 = m.alloc_id();
111 let c_uint_max_sg = m.alloc_id();
112 let c_float_0 = m.alloc_id();
113 let c_scope_sg = m.alloc_id();
114 let c_scope_wg = m.alloc_id();
115 let c_mem_sem = m.alloc_id();
116
117 let var_gid = m.alloc_id();
119 let var_sg_id = m.alloc_id();
120 let var_sg_lid = m.alloc_id();
121 let var_num_sg = m.alloc_id();
122 let var_scratch = m.alloc_id();
123
124 let fn_ty = m.alloc_id();
126 let main_fn = m.alloc_id();
127 let p_input = m.alloc_id();
128 let p_output = m.alloc_id();
129 let p_count = m.alloc_id();
130
131 m.emit_entry_point(
133 EXECUTION_MODEL_KERNEL,
134 main_fn,
135 "reduction_subgroup",
136 &[var_gid, var_sg_id, var_sg_lid, var_num_sg],
137 );
138 m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);
139
140 m.emit_decorate(var_gid, DECORATION_BUILTIN, &[BUILTIN_GLOBAL_INVOCATION_ID]);
142 m.emit_decorate(var_sg_id, DECORATION_BUILTIN, &[BUILTIN_SUBGROUP_ID]);
143 m.emit_decorate(
144 var_sg_lid,
145 DECORATION_BUILTIN,
146 &[BUILTIN_SUBGROUP_LOCAL_INVOCATION_ID],
147 );
148 m.emit_decorate(var_num_sg, DECORATION_BUILTIN, &[BUILTIN_NUM_SUBGROUPS]);
149
150 m.emit_type_void(ty_void);
152 m.emit_type_bool(ty_bool);
153 m.emit_type_int(ty_uint, 32, 0);
154 m.emit_type_float(ty_float, 32);
155 m.emit_type_vector(ty_v3uint, ty_uint, 3);
156 m.emit_type_pointer(ty_ptr_input_v3uint, STORAGE_CLASS_INPUT, ty_v3uint);
157 m.emit_type_pointer(ty_ptr_cross_float, STORAGE_CLASS_CROSS_WORKGROUP, ty_float);
158 m.emit_type_pointer(ty_ptr_func_float, STORAGE_CLASS_FUNCTION, ty_float);
159 m.emit_type_pointer(ty_ptr_wg_float, STORAGE_CLASS_WORKGROUP, ty_float);
160 m.emit_type_pointer(ty_ptr_input_uint, STORAGE_CLASS_INPUT, ty_uint);
161 m.emit_type_function(
162 fn_ty,
163 ty_void,
164 &[ty_ptr_cross_float, ty_ptr_cross_float, ty_uint],
165 );
166
167 m.emit_constant_u32(ty_uint, c_uint_0, 0);
169 m.emit_constant_u32(ty_uint, c_uint_max_sg, MAX_SUBGROUPS);
170 m.emit_constant_f32(ty_float, c_float_0, 0.0);
171 m.emit_constant_u32(ty_uint, c_scope_sg, SCOPE_SUBGROUP);
172 m.emit_constant_u32(ty_uint, c_scope_wg, SCOPE_WORKGROUP);
173 m.emit_constant_u32(ty_uint, c_mem_sem, MEMORY_SEMANTICS_WORKGROUP_MEMORY);
174
175 m.emit(OP_TYPE_ARRAY, &[ty_arr_float, ty_float, c_uint_max_sg]);
177 m.emit_type_pointer(ty_ptr_wg_arr, STORAGE_CLASS_WORKGROUP, ty_arr_float);
178
179 m.emit_variable(ty_ptr_input_v3uint, var_gid, STORAGE_CLASS_INPUT);
181 m.emit_variable(ty_ptr_input_uint, var_sg_id, STORAGE_CLASS_INPUT);
182 m.emit_variable(ty_ptr_input_uint, var_sg_lid, STORAGE_CLASS_INPUT);
183 m.emit_variable(ty_ptr_input_uint, var_num_sg, STORAGE_CLASS_INPUT);
184 m.emit_variable(ty_ptr_wg_arr, var_scratch, STORAGE_CLASS_WORKGROUP);
185
186 let label_entry = m.alloc_id();
188 let label_bounds_body = m.alloc_id();
189 let label_bounds_merge = m.alloc_id();
190 let label_leader = m.alloc_id();
191 let label_after_leader = m.alloc_id();
192 let label_sg0 = m.alloc_id();
193 let label_after_sg0 = m.alloc_id();
194
195 m.emit_function(ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
197 m.emit_function_parameter(ty_ptr_cross_float, p_input);
198 m.emit_function_parameter(ty_ptr_cross_float, p_output);
199 m.emit_function_parameter(ty_uint, p_count);
200 m.emit_label(label_entry);
201
202 let gid_val = m.alloc_id();
204 m.emit_load(ty_v3uint, gid_val, var_gid);
205 let gid_x = m.alloc_id();
206 m.emit(OP_COMPOSITE_EXTRACT, &[ty_uint, gid_x, gid_val, 0]);
207
208 let sg_id = m.alloc_id();
210 m.emit_load(ty_uint, sg_id, var_sg_id);
211 let sg_lid = m.alloc_id();
212 m.emit_load(ty_uint, sg_lid, var_sg_lid);
213 let num_sg = m.alloc_id();
214 m.emit_load(ty_uint, num_sg, var_num_sg);
215
216 let cond_bounds = m.alloc_id();
218 m.emit(OP_U_LESS_THAN, &[ty_bool, cond_bounds, gid_x, p_count]);
219
220 m.emit_selection_merge(label_bounds_merge);
222 m.emit_branch_conditional(cond_bounds, label_bounds_body, label_bounds_merge);
223
224 m.emit_label(label_bounds_body);
225 let inp_ptr = m.alloc_id();
226 m.emit_in_bounds_ptr_access_chain(ty_ptr_cross_float, inp_ptr, p_input, gid_x);
227 let inp_val = m.alloc_id();
228 m.emit_load(ty_float, inp_val, inp_ptr);
229 m.emit_branch(label_bounds_merge);
230
231 m.emit_label(label_bounds_merge);
232 let val = m.alloc_id();
234 m.emit(
235 OP_PHI,
236 &[
237 ty_float,
238 val,
239 inp_val,
240 label_bounds_body,
241 c_float_0,
242 label_entry,
243 ],
244 );
245
246 let sg_sum = m.alloc_id();
248 m.emit(
249 OP_GROUP_NON_UNIFORM_FADD,
250 &[ty_float, sg_sum, c_scope_sg, GROUP_OPERATION_REDUCE, val],
251 );
252
253 let is_leader_eq = m.alloc_id();
255 m.emit(OP_I_EQUAL, &[ty_bool, is_leader_eq, sg_lid, c_uint_0]);
256
257 m.emit_selection_merge(label_after_leader);
258 m.emit_branch_conditional(is_leader_eq, label_leader, label_after_leader);
259
260 m.emit_label(label_leader);
261 let scratch_ptr = m.alloc_id();
262 m.emit_in_bounds_ptr_access_chain(ty_ptr_wg_float, scratch_ptr, var_scratch, sg_id);
263 m.emit_store(scratch_ptr, sg_sum);
264 m.emit_branch(label_after_leader);
265
266 m.emit_label(label_after_leader);
267
268 m.emit(OP_CONTROL_BARRIER, &[c_scope_wg, c_scope_wg, c_mem_sem]);
270
271 let is_sg0 = m.alloc_id();
273 m.emit(OP_I_EQUAL, &[ty_bool, is_sg0, sg_id, c_uint_0]);
274
275 let lid_lt_nsg = m.alloc_id();
277 m.emit(OP_U_LESS_THAN, &[ty_bool, lid_lt_nsg, sg_lid, num_sg]);
278
279 m.emit_selection_merge(label_after_sg0);
280 m.emit_branch_conditional(is_sg0, label_sg0, label_after_sg0);
281
282 m.emit_label(label_sg0);
283
284 let s_ptr = m.alloc_id();
286 m.emit_in_bounds_ptr_access_chain(ty_ptr_wg_float, s_ptr, var_scratch, sg_lid);
287 let partial = m.alloc_id();
288 m.emit_load(ty_float, partial, s_ptr);
289
290 let safe_partial = m.alloc_id();
292 m.emit(
293 OP_SELECT,
294 &[ty_float, safe_partial, lid_lt_nsg, partial, c_float_0],
295 );
296
297 let final_sum = m.alloc_id();
299 m.emit(
300 OP_GROUP_NON_UNIFORM_FADD,
301 &[
302 ty_float,
303 final_sum,
304 c_scope_sg,
305 GROUP_OPERATION_REDUCE,
306 safe_partial,
307 ],
308 );
309
310 let is_lane0 = m.alloc_id();
312 m.emit(OP_I_EQUAL, &[ty_bool, is_lane0, sg_lid, c_uint_0]);
313 let label_write = m.alloc_id();
314 let label_after_write = m.alloc_id();
315 m.emit_selection_merge(label_after_write);
316 m.emit_branch_conditional(is_lane0, label_write, label_after_write);
317
318 m.emit_label(label_write);
319 let out_ptr = m.alloc_id();
320 m.emit_in_bounds_ptr_access_chain(ty_ptr_cross_float, out_ptr, p_output, c_uint_0);
321 m.emit_store(out_ptr, final_sum);
322 m.emit_branch(label_after_write);
323
324 m.emit_label(label_after_write);
325 m.emit_branch(label_after_sg0);
326
327 m.emit_label(label_after_sg0);
328 m.emit_return();
329 m.emit_function_end();
330
331 m.finalize()
332}
333
334pub fn scan_subgroup_spirv() -> Vec<u32> {
345 let mut m = SpvModule::new();
346
347 m.emit_capability(CAPABILITY_KERNEL);
349 m.emit_capability(CAPABILITY_ADDRESSES);
350 m.emit_capability(CAPABILITY_GROUP_NON_UNIFORM);
351 m.emit_capability(CAPABILITY_GROUP_NON_UNIFORM_ARITHMETIC);
352
353 let opencl_ext = m.alloc_id();
355 m.emit_ext_inst_import(opencl_ext, "OpenCL.std");
356
357 m.emit_memory_model(ADDRESSING_MODEL_PHYSICAL64, MEMORY_MODEL_OPENCL);
359
360 let ty_void = m.alloc_id();
362 let ty_bool = m.alloc_id();
363 let ty_uint = m.alloc_id();
364 let ty_float = m.alloc_id();
365 let ty_v3uint = m.alloc_id();
366 let ty_ptr_input_v3uint = m.alloc_id();
367 let ty_ptr_cross_float = m.alloc_id();
368 let ty_ptr_input_uint = m.alloc_id();
369 let ty_arr_float = m.alloc_id();
370 let ty_ptr_wg_float = m.alloc_id();
371 let ty_ptr_wg_arr = m.alloc_id();
372
373 let c_uint_0 = m.alloc_id();
375 let c_uint_1 = m.alloc_id();
376 let c_uint_max_sg = m.alloc_id();
377 let c_float_0 = m.alloc_id();
378 let c_scope_sg = m.alloc_id();
379 let c_scope_wg = m.alloc_id();
380 let c_mem_sem = m.alloc_id();
381
382 let var_gid = m.alloc_id();
384 let var_sg_id = m.alloc_id();
385 let var_sg_lid = m.alloc_id();
386 let var_num_sg = m.alloc_id();
387 let var_scratch = m.alloc_id();
388
389 let fn_ty = m.alloc_id();
391 let main_fn = m.alloc_id();
392 let p_input = m.alloc_id();
393 let p_output = m.alloc_id();
394 let p_count = m.alloc_id();
395
396 m.emit_entry_point(
398 EXECUTION_MODEL_KERNEL,
399 main_fn,
400 "scan_subgroup",
401 &[var_gid, var_sg_id, var_sg_lid, var_num_sg],
402 );
403 m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);
404
405 m.emit_decorate(var_gid, DECORATION_BUILTIN, &[BUILTIN_GLOBAL_INVOCATION_ID]);
407 m.emit_decorate(var_sg_id, DECORATION_BUILTIN, &[BUILTIN_SUBGROUP_ID]);
408 m.emit_decorate(
409 var_sg_lid,
410 DECORATION_BUILTIN,
411 &[BUILTIN_SUBGROUP_LOCAL_INVOCATION_ID],
412 );
413 m.emit_decorate(var_num_sg, DECORATION_BUILTIN, &[BUILTIN_NUM_SUBGROUPS]);
414
415 m.emit_type_void(ty_void);
417 m.emit_type_bool(ty_bool);
418 m.emit_type_int(ty_uint, 32, 0);
419 m.emit_type_float(ty_float, 32);
420 m.emit_type_vector(ty_v3uint, ty_uint, 3);
421 m.emit_type_pointer(ty_ptr_input_v3uint, STORAGE_CLASS_INPUT, ty_v3uint);
422 m.emit_type_pointer(ty_ptr_cross_float, STORAGE_CLASS_CROSS_WORKGROUP, ty_float);
423 m.emit_type_pointer(ty_ptr_input_uint, STORAGE_CLASS_INPUT, ty_uint);
424 m.emit_type_pointer(ty_ptr_wg_float, STORAGE_CLASS_WORKGROUP, ty_float);
425 m.emit_type_function(
426 fn_ty,
427 ty_void,
428 &[ty_ptr_cross_float, ty_ptr_cross_float, ty_uint],
429 );
430
431 m.emit_constant_u32(ty_uint, c_uint_0, 0);
433 m.emit_constant_u32(ty_uint, c_uint_1, 1);
434 m.emit_constant_u32(ty_uint, c_uint_max_sg, MAX_SUBGROUPS);
435 m.emit_constant_f32(ty_float, c_float_0, 0.0);
436 m.emit_constant_u32(ty_uint, c_scope_sg, SCOPE_SUBGROUP);
437 m.emit_constant_u32(ty_uint, c_scope_wg, SCOPE_WORKGROUP);
438 m.emit_constant_u32(ty_uint, c_mem_sem, MEMORY_SEMANTICS_WORKGROUP_MEMORY);
439
440 m.emit(OP_TYPE_ARRAY, &[ty_arr_float, ty_float, c_uint_max_sg]);
442 m.emit_type_pointer(ty_ptr_wg_arr, STORAGE_CLASS_WORKGROUP, ty_arr_float);
443
444 m.emit_variable(ty_ptr_input_v3uint, var_gid, STORAGE_CLASS_INPUT);
446 m.emit_variable(ty_ptr_input_uint, var_sg_id, STORAGE_CLASS_INPUT);
447 m.emit_variable(ty_ptr_input_uint, var_sg_lid, STORAGE_CLASS_INPUT);
448 m.emit_variable(ty_ptr_input_uint, var_num_sg, STORAGE_CLASS_INPUT);
449 m.emit_variable(ty_ptr_wg_arr, var_scratch, STORAGE_CLASS_WORKGROUP);
450
451 let label_entry = m.alloc_id();
453 let label_bounds_body = m.alloc_id();
454 let label_bounds_merge = m.alloc_id();
455 let label_leader = m.alloc_id();
456 let label_after_leader = m.alloc_id();
457 let label_add_prefix = m.alloc_id();
458 let label_after_add = m.alloc_id();
459 let label_write = m.alloc_id();
460 let label_end = m.alloc_id();
461
462 m.emit_function(ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
464 m.emit_function_parameter(ty_ptr_cross_float, p_input);
465 m.emit_function_parameter(ty_ptr_cross_float, p_output);
466 m.emit_function_parameter(ty_uint, p_count);
467 m.emit_label(label_entry);
468
469 let gid_val = m.alloc_id();
471 m.emit_load(ty_v3uint, gid_val, var_gid);
472 let gid_x = m.alloc_id();
473 m.emit(OP_COMPOSITE_EXTRACT, &[ty_uint, gid_x, gid_val, 0]);
474
475 let sg_id = m.alloc_id();
477 m.emit_load(ty_uint, sg_id, var_sg_id);
478 let sg_lid = m.alloc_id();
479 m.emit_load(ty_uint, sg_lid, var_sg_lid);
480
481 let in_bounds = m.alloc_id();
483 m.emit(OP_U_LESS_THAN, &[ty_bool, in_bounds, gid_x, p_count]);
484 m.emit_selection_merge(label_bounds_merge);
485 m.emit_branch_conditional(in_bounds, label_bounds_body, label_bounds_merge);
486
487 m.emit_label(label_bounds_body);
488 let inp_ptr = m.alloc_id();
489 m.emit_in_bounds_ptr_access_chain(ty_ptr_cross_float, inp_ptr, p_input, gid_x);
490 let inp_val = m.alloc_id();
491 m.emit_load(ty_float, inp_val, inp_ptr);
492 m.emit_branch(label_bounds_merge);
493
494 m.emit_label(label_bounds_merge);
495 let val = m.alloc_id();
496 m.emit(
497 OP_PHI,
498 &[
499 ty_float,
500 val,
501 inp_val,
502 label_bounds_body,
503 c_float_0,
504 label_entry,
505 ],
506 );
507
508 let sg_scan = m.alloc_id();
510 m.emit(
511 OP_GROUP_NON_UNIFORM_FADD,
512 &[
513 ty_float,
514 sg_scan,
515 c_scope_sg,
516 GROUP_OPERATION_INCLUSIVE_SCAN,
517 val,
518 ],
519 );
520
521 let sg_total = m.alloc_id();
523 m.emit(
524 OP_GROUP_NON_UNIFORM_FADD,
525 &[ty_float, sg_total, c_scope_sg, GROUP_OPERATION_REDUCE, val],
526 );
527
528 let is_leader = m.alloc_id();
530 m.emit(OP_I_EQUAL, &[ty_bool, is_leader, sg_lid, c_uint_0]);
531
532 m.emit_selection_merge(label_after_leader);
533 m.emit_branch_conditional(is_leader, label_leader, label_after_leader);
534
535 m.emit_label(label_leader);
536 let scratch_ptr = m.alloc_id();
537 m.emit_in_bounds_ptr_access_chain(ty_ptr_wg_float, scratch_ptr, var_scratch, sg_id);
538 m.emit_store(scratch_ptr, sg_total);
539 m.emit_branch(label_after_leader);
540
541 m.emit_label(label_after_leader);
542
543 m.emit(OP_CONTROL_BARRIER, &[c_scope_wg, c_scope_wg, c_mem_sem]);
545
546 let has_prefix = m.alloc_id();
549 m.emit(OP_U_LESS_THAN, &[ty_bool, has_prefix, c_uint_0, sg_id]); m.emit_selection_merge(label_after_add);
552 m.emit_branch_conditional(has_prefix, label_add_prefix, label_after_add);
553
554 m.emit_label(label_add_prefix);
555
556 let var_j = m.alloc_id();
558 let ty_ptr_func_uint = m.alloc_id();
559 m.emit_type_pointer(ty_ptr_func_uint, STORAGE_CLASS_FUNCTION, ty_uint);
560 let var_prefix_acc = m.alloc_id();
561 let ty_ptr_func_float = m.alloc_id();
562 m.emit_type_pointer(ty_ptr_func_float, STORAGE_CLASS_FUNCTION, ty_float);
563 m.emit_variable(ty_ptr_func_uint, var_j, STORAGE_CLASS_FUNCTION);
564 m.emit_variable(ty_ptr_func_float, var_prefix_acc, STORAGE_CLASS_FUNCTION);
565 m.emit_store(var_j, c_uint_0);
566 m.emit_store(var_prefix_acc, c_float_0);
567
568 let lbl_loop_hdr = m.alloc_id();
569 let lbl_loop_body = m.alloc_id();
570 let lbl_loop_cont = m.alloc_id();
571 let lbl_loop_merge = m.alloc_id();
572
573 m.emit_branch(lbl_loop_hdr);
574 m.emit_label(lbl_loop_hdr);
575 let j_val = m.alloc_id();
576 m.emit_load(ty_uint, j_val, var_j);
577 let loop_cond = m.alloc_id();
578 m.emit(OP_U_LESS_THAN, &[ty_bool, loop_cond, j_val, sg_id]);
579 m.emit_loop_merge(lbl_loop_merge, lbl_loop_cont);
580 m.emit_branch_conditional(loop_cond, lbl_loop_body, lbl_loop_merge);
581
582 m.emit_label(lbl_loop_body);
583 let s_ptr = m.alloc_id();
584 m.emit_in_bounds_ptr_access_chain(ty_ptr_wg_float, s_ptr, var_scratch, j_val);
585 let s_val = m.alloc_id();
586 m.emit_load(ty_float, s_val, s_ptr);
587 let old_prefix = m.alloc_id();
588 m.emit_load(ty_float, old_prefix, var_prefix_acc);
589 let new_prefix = m.alloc_id();
590 m.emit(OP_F_ADD, &[ty_float, new_prefix, old_prefix, s_val]);
591 m.emit_store(var_prefix_acc, new_prefix);
592 m.emit_branch(lbl_loop_cont);
593
594 m.emit_label(lbl_loop_cont);
595 let j_inc = m.alloc_id();
596 m.emit(OP_I_ADD, &[ty_uint, j_inc, j_val, c_uint_1]);
597 m.emit_store(var_j, j_inc);
598 m.emit_branch(lbl_loop_hdr);
599
600 m.emit_label(lbl_loop_merge);
601 let prefix_val = m.alloc_id();
602 m.emit_load(ty_float, prefix_val, var_prefix_acc);
603 m.emit_branch(label_after_add);
604
605 m.emit_label(label_after_add);
606 let prefix = m.alloc_id();
608 m.emit(
609 OP_PHI,
610 &[
611 ty_float,
612 prefix,
613 prefix_val,
614 lbl_loop_merge,
615 c_float_0,
616 label_after_leader,
617 ],
618 );
619
620 let final_val = m.alloc_id();
622 m.emit(OP_F_ADD, &[ty_float, final_val, sg_scan, prefix]);
623
624 m.emit_selection_merge(label_end);
626 m.emit_branch_conditional(in_bounds, label_write, label_end);
627
628 m.emit_label(label_write);
629 let out_ptr = m.alloc_id();
630 m.emit_in_bounds_ptr_access_chain(ty_ptr_cross_float, out_ptr, p_output, gid_x);
631 m.emit_store(out_ptr, final_val);
632 m.emit_branch(label_end);
633
634 m.emit_label(label_end);
635 m.emit_return();
636 m.emit_function_end();
637
638 m.finalize()
639}
640
641pub fn gemm_subgroup_spirv() -> Vec<u32> {
658 let mut m = SpvModule::new();
659
660 m.emit_capability(CAPABILITY_KERNEL);
662 m.emit_capability(CAPABILITY_ADDRESSES);
663 m.emit_capability(CAPABILITY_GROUP_NON_UNIFORM);
664 m.emit_capability(CAPABILITY_GROUP_NON_UNIFORM_SHUFFLE);
665
666 let opencl_ext = m.alloc_id();
668 m.emit_ext_inst_import(opencl_ext, "OpenCL.std");
669
670 m.emit_memory_model(ADDRESSING_MODEL_PHYSICAL64, MEMORY_MODEL_OPENCL);
672
673 let ty_void = m.alloc_id();
675 let ty_bool = m.alloc_id();
676 let ty_uint = m.alloc_id();
677 let ty_float = m.alloc_id();
678 let ty_v3uint = m.alloc_id();
679 let ty_ptr_input_v3uint = m.alloc_id();
680 let ty_ptr_cross_float = m.alloc_id();
681 let ty_ptr_func_float = m.alloc_id();
682 let ty_ptr_func_uint = m.alloc_id();
683 let ty_ptr_input_uint = m.alloc_id();
684
685 let c_uint_0 = m.alloc_id();
687 let c_uint_1 = m.alloc_id();
688 let c_float_0 = m.alloc_id();
689 let c_scope_sg = m.alloc_id();
690
691 let var_gid = m.alloc_id();
693 let var_sg_lid = m.alloc_id();
694
695 let fn_ty = m.alloc_id();
697 let main_fn = m.alloc_id();
698 let p_a = m.alloc_id();
699 let p_b = m.alloc_id();
700 let p_c = m.alloc_id();
701 let p_m = m.alloc_id();
702 let p_n = m.alloc_id();
703 let p_k = m.alloc_id();
704 let p_alpha = m.alloc_id();
705 let p_beta = m.alloc_id();
706
707 m.emit_entry_point(
709 EXECUTION_MODEL_KERNEL,
710 main_fn,
711 "gemm_subgroup",
712 &[var_gid, var_sg_lid],
713 );
714 m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);
715
716 m.emit_decorate(var_gid, DECORATION_BUILTIN, &[BUILTIN_GLOBAL_INVOCATION_ID]);
718 m.emit_decorate(
719 var_sg_lid,
720 DECORATION_BUILTIN,
721 &[BUILTIN_SUBGROUP_LOCAL_INVOCATION_ID],
722 );
723
724 m.emit_type_void(ty_void);
726 m.emit_type_bool(ty_bool);
727 m.emit_type_int(ty_uint, 32, 0);
728 m.emit_type_float(ty_float, 32);
729 m.emit_type_vector(ty_v3uint, ty_uint, 3);
730 m.emit_type_pointer(ty_ptr_input_v3uint, STORAGE_CLASS_INPUT, ty_v3uint);
731 m.emit_type_pointer(ty_ptr_cross_float, STORAGE_CLASS_CROSS_WORKGROUP, ty_float);
732 m.emit_type_pointer(ty_ptr_func_float, STORAGE_CLASS_FUNCTION, ty_float);
733 m.emit_type_pointer(ty_ptr_func_uint, STORAGE_CLASS_FUNCTION, ty_uint);
734 m.emit_type_pointer(ty_ptr_input_uint, STORAGE_CLASS_INPUT, ty_uint);
735 m.emit_type_function(
736 fn_ty,
737 ty_void,
738 &[
739 ty_ptr_cross_float,
740 ty_ptr_cross_float,
741 ty_ptr_cross_float,
742 ty_uint,
743 ty_uint,
744 ty_uint,
745 ty_float,
746 ty_float,
747 ],
748 );
749
750 m.emit_constant_u32(ty_uint, c_uint_0, 0);
752 m.emit_constant_u32(ty_uint, c_uint_1, 1);
753 m.emit_constant_f32(ty_float, c_float_0, 0.0);
754 m.emit_constant_u32(ty_uint, c_scope_sg, SCOPE_SUBGROUP);
755
756 m.emit_variable(ty_ptr_input_v3uint, var_gid, STORAGE_CLASS_INPUT);
758 m.emit_variable(ty_ptr_input_uint, var_sg_lid, STORAGE_CLASS_INPUT);
759
760 let label_entry = m.alloc_id();
762 let label_bounds_body = m.alloc_id();
763 let label_bounds_merge = m.alloc_id();
764 let label_loop_header = m.alloc_id();
765 let label_loop_body = m.alloc_id();
766 let label_loop_continue = m.alloc_id();
767 let label_loop_merge = m.alloc_id();
768
769 m.emit_function(ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
771 m.emit_function_parameter(ty_ptr_cross_float, p_a);
772 m.emit_function_parameter(ty_ptr_cross_float, p_b);
773 m.emit_function_parameter(ty_ptr_cross_float, p_c);
774 m.emit_function_parameter(ty_uint, p_m);
775 m.emit_function_parameter(ty_uint, p_n);
776 m.emit_function_parameter(ty_uint, p_k);
777 m.emit_function_parameter(ty_float, p_alpha);
778 m.emit_function_parameter(ty_float, p_beta);
779 m.emit_label(label_entry);
780
781 let gid_val = m.alloc_id();
783 m.emit_load(ty_v3uint, gid_val, var_gid);
784 let gid_x = m.alloc_id();
785 m.emit(OP_COMPOSITE_EXTRACT, &[ty_uint, gid_x, gid_val, 0]);
786
787 let sg_lid = m.alloc_id();
789 m.emit_load(ty_uint, sg_lid, var_sg_lid);
790
791 let total = m.alloc_id();
793 m.emit(OP_I_MUL, &[ty_uint, total, p_m, p_n]);
794
795 let cond = m.alloc_id();
797 m.emit(OP_U_LESS_THAN, &[ty_bool, cond, gid_x, total]);
798 m.emit_selection_merge(label_bounds_merge);
799 m.emit_branch_conditional(cond, label_bounds_body, label_bounds_merge);
800
801 m.emit_label(label_bounds_body);
802
803 let row = m.alloc_id();
806 m.emit(134, &[ty_uint, row, gid_x, p_n]); let col = m.alloc_id();
808 m.emit(137, &[ty_uint, col, gid_x, p_n]); let var_acc = m.alloc_id();
812 m.emit_variable(ty_ptr_func_float, var_acc, STORAGE_CLASS_FUNCTION);
813 m.emit_store(var_acc, c_float_0);
814 let var_i = m.alloc_id();
815 m.emit_variable(ty_ptr_func_uint, var_i, STORAGE_CLASS_FUNCTION);
816 m.emit_store(var_i, c_uint_0);
817
818 m.emit_branch(label_loop_header);
819
820 m.emit_label(label_loop_header);
822 let i_val = m.alloc_id();
823 m.emit_load(ty_uint, i_val, var_i);
824 let loop_cond = m.alloc_id();
825 m.emit(OP_U_LESS_THAN, &[ty_bool, loop_cond, i_val, p_k]);
826 m.emit_loop_merge(label_loop_merge, label_loop_continue);
827 m.emit_branch_conditional(loop_cond, label_loop_body, label_loop_merge);
828
829 m.emit_label(label_loop_body);
831
832 let a_idx = m.alloc_id();
834 let row_k = m.alloc_id();
835 m.emit(OP_I_MUL, &[ty_uint, row_k, row, p_k]);
836 m.emit(OP_I_ADD, &[ty_uint, a_idx, row_k, i_val]);
837 let a_ptr = m.alloc_id();
838 m.emit_in_bounds_ptr_access_chain(ty_ptr_cross_float, a_ptr, p_a, a_idx);
839 let a_val = m.alloc_id();
840 m.emit_load(ty_float, a_val, a_ptr);
841
842 let a_broadcast = m.alloc_id();
844 m.emit(
845 OP_GROUP_NON_UNIFORM_SHUFFLE,
846 &[ty_float, a_broadcast, c_scope_sg, a_val, sg_lid],
847 );
848
849 let b_idx = m.alloc_id();
851 let i_n = m.alloc_id();
852 m.emit(OP_I_MUL, &[ty_uint, i_n, i_val, p_n]);
853 m.emit(OP_I_ADD, &[ty_uint, b_idx, i_n, col]);
854 let b_ptr = m.alloc_id();
855 m.emit_in_bounds_ptr_access_chain(ty_ptr_cross_float, b_ptr, p_b, b_idx);
856 let b_val = m.alloc_id();
857 m.emit_load(ty_float, b_val, b_ptr);
858
859 let prod = m.alloc_id();
861 m.emit(OP_F_MUL, &[ty_float, prod, a_broadcast, b_val]);
862 let old_acc = m.alloc_id();
863 m.emit_load(ty_float, old_acc, var_acc);
864 let new_acc = m.alloc_id();
865 m.emit(OP_F_ADD, &[ty_float, new_acc, old_acc, prod]);
866 m.emit_store(var_acc, new_acc);
867
868 m.emit_branch(label_loop_continue);
869
870 m.emit_label(label_loop_continue);
872 let i_inc = m.alloc_id();
873 m.emit(OP_I_ADD, &[ty_uint, i_inc, i_val, c_uint_1]);
874 m.emit_store(var_i, i_inc);
875 m.emit_branch(label_loop_header);
876
877 m.emit_label(label_loop_merge);
879
880 let final_acc = m.alloc_id();
882 m.emit_load(ty_float, final_acc, var_acc);
883 let alpha_acc = m.alloc_id();
884 m.emit(OP_F_MUL, &[ty_float, alpha_acc, p_alpha, final_acc]);
885
886 let c_ptr = m.alloc_id();
887 m.emit_in_bounds_ptr_access_chain(ty_ptr_cross_float, c_ptr, p_c, gid_x);
888 let c_old = m.alloc_id();
889 m.emit_load(ty_float, c_old, c_ptr);
890 let beta_c = m.alloc_id();
891 m.emit(OP_F_MUL, &[ty_float, beta_c, p_beta, c_old]);
892 let c_new = m.alloc_id();
893 m.emit(OP_F_ADD, &[ty_float, c_new, alpha_acc, beta_c]);
894 m.emit_store(c_ptr, c_new);
895
896 m.emit_branch(label_bounds_merge);
897
898 m.emit_label(label_bounds_merge);
899 m.emit_return();
900 m.emit_function_end();
901
902 m.finalize()
903}
904
#[cfg(test)]
mod tests {
    use super::*;
    use crate::spirv::SPIRV_MAGIC;

    /// Words in the SPIR-V module header: magic, version, generator, bound, schema.
    const HEADER_WORDS: usize = 5;
    /// Opcode of `OpEntryPoint`.
    const OP_ENTRY_POINT: u32 = 15;
    /// Opcode of `OpCapability`.
    const OP_CAPABILITY: u32 = 17;

    /// Walks the instruction stream that follows the module header and yields
    /// `(opcode, operands)` for every instruction.
    ///
    /// The first word of an instruction stores its total word count in the
    /// high 16 bits and its opcode in the low 16 bits. Panics when a word
    /// count is zero or runs past the end of the module.
    fn instructions<'a>(words: &'a [u32]) -> impl Iterator<Item = (u32, &'a [u32])> + 'a {
        let mut rest: &'a [u32] = words.get(HEADER_WORDS..).unwrap_or(&[]);
        std::iter::from_fn(move || {
            let head = *rest.first()?;
            let word_count = (head >> 16) as usize;
            assert!(
                (1..=rest.len()).contains(&word_count),
                "instruction word count {} is zero or overruns the module",
                word_count
            );
            let (instruction, tail) = rest.split_at(word_count);
            rest = tail;
            Some((head & 0xFFFF, &instruction[1..]))
        })
    }

    /// Asserts the header invariants: magic number, non-zero id bound, zero schema.
    fn check_valid_spirv(words: &[u32]) {
        assert!(words.len() >= HEADER_WORDS, "too short for SPIR-V header");
        assert_eq!(words[0], SPIRV_MAGIC, "bad magic");
        assert!(words[3] > 0, "ID bound must be > 0");
        assert_eq!(words[4], 0, "schema must be 0");
    }

    /// Asserts that the instruction word counts tile the module exactly, i.e.
    /// every instruction starts where the previous one ends and the last one
    /// ends on the final word.
    fn check_word_aligned(words: &[u32]) {
        let consumed: usize = instructions(words)
            .map(|(_, operands)| operands.len() + 1)
            .sum();
        assert_eq!(
            HEADER_WORDS + consumed,
            words.len(),
            "instruction stream does not end on the module boundary"
        );
    }

    /// Returns `true` when the module contains `OpCapability cap`.
    ///
    /// Only real instruction boundaries are inspected, so operand words of
    /// other instructions cannot produce a false match.
    fn has_capability(words: &[u32], cap: u32) -> bool {
        instructions(words).any(|(opcode, operands)| opcode == OP_CAPABILITY && operands == [cap])
    }

    /// Returns `true` when the module has an `OpEntryPoint` whose name is
    /// exactly `name`.
    fn has_entry_point(words: &[u32], name: &str) -> bool {
        instructions(words).any(|(opcode, operands)| {
            // Operands: execution model, function id, then the name as a
            // NUL-terminated literal packed little-endian into words.
            opcode == OP_ENTRY_POINT
                && operands.len() > 2
                && operands[2..]
                    .iter()
                    .flat_map(|w| w.to_le_bytes())
                    .take_while(|&b| b != 0)
                    .eq(name.bytes())
        })
    }

    #[test]
    fn reduction_subgroup_valid_spirv() {
        check_valid_spirv(&reduction_subgroup_spirv());
    }

    #[test]
    fn reduction_subgroup_word_aligned() {
        check_word_aligned(&reduction_subgroup_spirv());
    }

    #[test]
    fn reduction_subgroup_has_group_non_uniform_capability() {
        let words = reduction_subgroup_spirv();
        assert!(
            has_capability(&words, CAPABILITY_GROUP_NON_UNIFORM),
            "missing GroupNonUniform capability"
        );
        assert!(
            has_capability(&words, CAPABILITY_GROUP_NON_UNIFORM_ARITHMETIC),
            "missing GroupNonUniformArithmetic capability"
        );
    }

    #[test]
    fn reduction_subgroup_has_entry_point() {
        assert!(
            has_entry_point(&reduction_subgroup_spirv(), "reduction_subgroup"),
            "missing entry point name"
        );
    }

    #[test]
    fn scan_subgroup_valid_spirv() {
        check_valid_spirv(&scan_subgroup_spirv());
    }

    #[test]
    fn scan_subgroup_word_aligned() {
        check_word_aligned(&scan_subgroup_spirv());
    }

    #[test]
    fn scan_subgroup_has_group_non_uniform_capability() {
        let words = scan_subgroup_spirv();
        assert!(
            has_capability(&words, CAPABILITY_GROUP_NON_UNIFORM),
            "missing GroupNonUniform capability"
        );
        assert!(
            has_capability(&words, CAPABILITY_GROUP_NON_UNIFORM_ARITHMETIC),
            "missing GroupNonUniformArithmetic capability"
        );
    }

    #[test]
    fn scan_subgroup_has_entry_point() {
        assert!(
            has_entry_point(&scan_subgroup_spirv(), "scan_subgroup"),
            "missing entry point name"
        );
    }

    #[test]
    fn gemm_subgroup_valid_spirv() {
        check_valid_spirv(&gemm_subgroup_spirv());
    }

    #[test]
    fn gemm_subgroup_word_aligned() {
        check_word_aligned(&gemm_subgroup_spirv());
    }

    #[test]
    fn gemm_subgroup_has_group_non_uniform_capability() {
        let words = gemm_subgroup_spirv();
        assert!(
            has_capability(&words, CAPABILITY_GROUP_NON_UNIFORM),
            "missing GroupNonUniform capability"
        );
        assert!(
            has_capability(&words, CAPABILITY_GROUP_NON_UNIFORM_SHUFFLE),
            "missing GroupNonUniformShuffle capability"
        );
    }

    #[test]
    fn gemm_subgroup_has_entry_point() {
        assert!(
            has_entry_point(&gemm_subgroup_spirv(), "gemm_subgroup"),
            "missing entry point name"
        );
    }

    #[test]
    fn subgroup_shaders_all_word_aligned() {
        let modules = [
            reduction_subgroup_spirv(),
            scan_subgroup_spirv(),
            gemm_subgroup_spirv(),
        ];
        for words in &modules {
            check_word_aligned(words);
        }
    }
}