1use oxicuda_backend::{BinaryOp, ReduceOp, UnaryOp};
17
/// First word of every SPIR-V binary module.
pub const SPIRV_MAGIC: u32 = 0x07230203;
/// Version word: SPIR-V 1.2, encoded as 0x00MMmm00 (major/minor).
pub const SPIRV_VERSION_1_2: u32 = 0x0001_0200;
/// Generator word: tool id in the high 16 bits, tool version in the low 16.
pub const SPIRV_GENERATOR: u32 = 0x000D_0002;

// Instruction opcodes; numeric values come from the SPIR-V specification.
pub(crate) const OP_EXT_INST_IMPORT: u32 = 11;
pub(crate) const OP_EXT_INST: u32 = 12;
pub(crate) const OP_MEMORY_MODEL: u32 = 14;
pub(crate) const OP_ENTRY_POINT: u32 = 15;
pub(crate) const OP_EXECUTION_MODE: u32 = 16;
pub(crate) const OP_CAPABILITY: u32 = 17;
pub(crate) const OP_TYPE_VOID: u32 = 19;
pub(crate) const OP_TYPE_BOOL: u32 = 20;
pub(crate) const OP_TYPE_INT: u32 = 21;
pub(crate) const OP_TYPE_FLOAT: u32 = 22;
pub(crate) const OP_TYPE_VECTOR: u32 = 23;
pub(crate) const OP_TYPE_ARRAY: u32 = 28;
pub(crate) const OP_TYPE_POINTER: u32 = 32;
pub(crate) const OP_TYPE_FUNCTION: u32 = 33;
pub(crate) const OP_CONSTANT: u32 = 43;
pub(crate) const OP_FUNCTION: u32 = 54;
pub(crate) const OP_FUNCTION_PARAMETER: u32 = 55;
pub(crate) const OP_FUNCTION_END: u32 = 56;
pub(crate) const OP_VARIABLE: u32 = 59;
pub(crate) const OP_LOAD: u32 = 61;
pub(crate) const OP_STORE: u32 = 62;
pub(crate) const OP_IN_BOUNDS_PTR_ACCESS_CHAIN: u32 = 70;
pub(crate) const OP_DECORATE: u32 = 71;
pub(crate) const OP_COMPOSITE_EXTRACT: u32 = 81;
pub(crate) const OP_CONVERT_U_TO_F: u32 = 112;
pub(crate) const OP_F_NEGATE: u32 = 127;
pub(crate) const OP_I_ADD: u32 = 128;
pub(crate) const OP_F_ADD: u32 = 129;
pub(crate) const OP_F_SUB: u32 = 131;
pub(crate) const OP_I_MUL: u32 = 132;
pub(crate) const OP_F_MUL: u32 = 133;
pub(crate) const OP_U_DIV: u32 = 134;
pub(crate) const OP_F_DIV: u32 = 136;
pub(crate) const OP_U_MOD: u32 = 137;
pub(crate) const OP_U_LESS_THAN: u32 = 176;
pub(crate) const OP_LOOP_MERGE: u32 = 246;
pub(crate) const OP_SELECTION_MERGE: u32 = 247;
pub(crate) const OP_LABEL: u32 = 248;
pub(crate) const OP_BRANCH: u32 = 249;
pub(crate) const OP_BRANCH_CONDITIONAL: u32 = 250;
pub(crate) const OP_CONTROL_BARRIER: u32 = 224;
pub(crate) const OP_PHI: u32 = 245;
pub(crate) const OP_RETURN: u32 = 253;

// Subgroup (non-uniform group) operations.
pub(crate) const OP_GROUP_NON_UNIFORM_FADD: u32 = 350;
pub(crate) const OP_GROUP_NON_UNIFORM_SHUFFLE: u32 = 345;

// Capabilities declared via OpCapability.
const CAPABILITY_SHADER: u32 = 1;
const CAPABILITY_ADDRESSES: u32 = 4;
const CAPABILITY_KERNEL: u32 = 6;

// OpMemoryModel operands: addressing model and memory model.
const ADDRESSING_MODEL_LOGICAL: u32 = 0;
const ADDRESSING_MODEL_PHYSICAL64: u32 = 2;
const MEMORY_MODEL_GLSL450: u32 = 1;
const MEMORY_MODEL_OPENCL: u32 = 2;

// Execution model (OpEntryPoint) and execution mode (OpExecutionMode).
const EXECUTION_MODEL_GLCOMPUTE: u32 = 5;
pub(crate) const EXECUTION_MODEL_KERNEL: u32 = 6;
const EXECUTION_MODE_LOCAL_SIZE: u32 = 17;

// OpFunction control mask (no inlining/purity hints).
pub(crate) const FUNCTION_CONTROL_NONE: u32 = 0;

// Decoration kind used with OpDecorate.
const DECORATION_BUILTIN: u32 = 11;

// BuiltIn decoration value for gl_GlobalInvocationID.
const BUILTIN_GLOBAL_INVOCATION_ID: u32 = 28;

// Storage classes for OpTypePointer / OpVariable.
const STORAGE_CLASS_INPUT: u32 = 1;
const STORAGE_CLASS_CROSS_WORKGROUP: u32 = 5;
pub(crate) const STORAGE_CLASS_FUNCTION: u32 = 7;

// Control masks for structured control-flow merge instructions.
const SELECTION_CONTROL_NONE: u32 = 0;
const LOOP_CONTROL_NONE: u32 = 0;

// Instruction numbers from the OpenCL.std extended instruction set.
pub(crate) const OPENCL_EXP: u32 = 19;
const OPENCL_FABS: u32 = 23;
pub(crate) const OPENCL_FMAX: u32 = 27;
const OPENCL_FMIN: u32 = 28;
const OPENCL_LOG: u32 = 37;
const OPENCL_SQRT: u32 = 61;
const OPENCL_TANH: u32 = 63;

/// Threads per workgroup in x; every kernel here is launched with (256, 1, 1).
pub(crate) const WORKGROUP_SIZE: u32 = 256;
122
/// Incrementally built SPIR-V module: a flat word stream plus the running
/// result-id bound. Instructions land in the stream in the exact order the
/// `emit*` methods are called; `finalize` patches the header and returns it.
pub struct SpvModule {
    /// Raw SPIR-V words, beginning with the five-word header.
    words: Vec<u32>,
    /// Next unused result id; written into header word 3 on finalize.
    id_bound: u32,
}
134
135impl SpvModule {
136 pub fn new() -> Self {
138 let words = vec![SPIRV_MAGIC, SPIRV_VERSION_1_2, SPIRV_GENERATOR, 0, 0];
139 Self { words, id_bound: 1 }
140 }
141
142 pub fn alloc_id(&mut self) -> u32 {
144 let id = self.id_bound;
145 self.id_bound += 1;
146 id
147 }
148
149 pub fn emit(&mut self, opcode: u32, operands: &[u32]) {
151 let word_count = (1 + operands.len()) as u32;
152 self.words.push((word_count << 16) | opcode);
153 self.words.extend_from_slice(operands);
154 }
155
156 pub fn string_words(s: &str) -> Vec<u32> {
158 let bytes = s.as_bytes();
159 let padded_len = (bytes.len() + 4) & !3;
160 let mut out = vec![0u32; padded_len / 4];
161 for (i, &b) in bytes.iter().enumerate() {
162 out[i / 4] |= (b as u32) << ((i % 4) * 8);
163 }
164 out
165 }
166
167 pub fn finalize(mut self) -> Vec<u32> {
169 self.words[3] = self.id_bound;
170 self.words
171 }
172
173 pub(crate) fn emit_capability(&mut self, cap: u32) {
176 self.emit(OP_CAPABILITY, &[cap]);
177 }
178
179 pub(crate) fn emit_ext_inst_import(&mut self, id: u32, name: &str) {
180 let mut ops = vec![id];
181 ops.extend(Self::string_words(name));
182 self.emit(OP_EXT_INST_IMPORT, &ops);
183 }
184
185 pub(crate) fn emit_memory_model(&mut self, addressing: u32, memory: u32) {
186 self.emit(OP_MEMORY_MODEL, &[addressing, memory]);
187 }
188
189 pub(crate) fn emit_entry_point(
190 &mut self,
191 model: u32,
192 func_id: u32,
193 name: &str,
194 interfaces: &[u32],
195 ) {
196 let mut ops = vec![model, func_id];
197 ops.extend(Self::string_words(name));
198 ops.extend_from_slice(interfaces);
199 self.emit(OP_ENTRY_POINT, &ops);
200 }
201
202 pub(crate) fn emit_execution_mode_local_size(&mut self, func_id: u32, x: u32, y: u32, z: u32) {
203 self.emit(
204 OP_EXECUTION_MODE,
205 &[func_id, EXECUTION_MODE_LOCAL_SIZE, x, y, z],
206 );
207 }
208
209 pub(crate) fn emit_decorate(&mut self, target: u32, decoration: u32, operands: &[u32]) {
210 let mut ops = vec![target, decoration];
211 ops.extend_from_slice(operands);
212 self.emit(OP_DECORATE, &ops);
213 }
214
215 pub(crate) fn emit_type_void(&mut self, id: u32) {
216 self.emit(OP_TYPE_VOID, &[id]);
217 }
218
219 pub(crate) fn emit_type_bool(&mut self, id: u32) {
220 self.emit(OP_TYPE_BOOL, &[id]);
221 }
222
223 pub(crate) fn emit_type_int(&mut self, id: u32, width: u32, signedness: u32) {
224 self.emit(OP_TYPE_INT, &[id, width, signedness]);
225 }
226
227 pub(crate) fn emit_type_float(&mut self, id: u32, width: u32) {
228 self.emit(OP_TYPE_FLOAT, &[id, width]);
229 }
230
231 pub(crate) fn emit_type_vector(&mut self, id: u32, component: u32, count: u32) {
232 self.emit(OP_TYPE_VECTOR, &[id, component, count]);
233 }
234
235 pub(crate) fn emit_type_pointer(&mut self, id: u32, storage_class: u32, pointee: u32) {
236 self.emit(OP_TYPE_POINTER, &[id, storage_class, pointee]);
237 }
238
239 pub(crate) fn emit_type_function(&mut self, id: u32, return_type: u32, params: &[u32]) {
240 let mut ops = vec![id, return_type];
241 ops.extend_from_slice(params);
242 self.emit(OP_TYPE_FUNCTION, &ops);
243 }
244
245 pub(crate) fn emit_constant_u32(&mut self, ty: u32, id: u32, value: u32) {
246 self.emit(OP_CONSTANT, &[ty, id, value]);
247 }
248
249 pub(crate) fn emit_constant_f32(&mut self, ty: u32, id: u32, value: f32) {
250 self.emit(OP_CONSTANT, &[ty, id, value.to_bits()]);
251 }
252
253 pub(crate) fn emit_variable(&mut self, ty: u32, id: u32, storage_class: u32) {
254 self.emit(OP_VARIABLE, &[ty, id, storage_class]);
255 }
256
257 pub(crate) fn emit_load(&mut self, result_ty: u32, result: u32, pointer: u32) {
258 self.emit(OP_LOAD, &[result_ty, result, pointer]);
259 }
260
261 pub(crate) fn emit_store(&mut self, pointer: u32, value: u32) {
262 self.emit(OP_STORE, &[pointer, value]);
263 }
264
265 pub(crate) fn emit_in_bounds_ptr_access_chain(
266 &mut self,
267 result_ty: u32,
268 result: u32,
269 base: u32,
270 element: u32,
271 ) {
272 self.emit(
273 OP_IN_BOUNDS_PTR_ACCESS_CHAIN,
274 &[result_ty, result, base, element],
275 );
276 }
277
278 pub(crate) fn emit_function(&mut self, result_ty: u32, result: u32, control: u32, fn_ty: u32) {
279 self.emit(OP_FUNCTION, &[result_ty, result, control, fn_ty]);
280 }
281
282 pub(crate) fn emit_function_parameter(&mut self, result_ty: u32, result: u32) {
283 self.emit(OP_FUNCTION_PARAMETER, &[result_ty, result]);
284 }
285
286 pub(crate) fn emit_label(&mut self, id: u32) {
287 self.emit(OP_LABEL, &[id]);
288 }
289
290 pub(crate) fn emit_return(&mut self) {
291 self.emit(OP_RETURN, &[]);
292 }
293
294 pub(crate) fn emit_function_end(&mut self) {
295 self.emit(OP_FUNCTION_END, &[]);
296 }
297
298 pub(crate) fn emit_branch(&mut self, target: u32) {
299 self.emit(OP_BRANCH, &[target]);
300 }
301
302 pub(crate) fn emit_branch_conditional(&mut self, cond: u32, true_label: u32, false_label: u32) {
303 self.emit(OP_BRANCH_CONDITIONAL, &[cond, true_label, false_label]);
304 }
305
306 pub(crate) fn emit_selection_merge(&mut self, merge_label: u32) {
307 self.emit(OP_SELECTION_MERGE, &[merge_label, SELECTION_CONTROL_NONE]);
308 }
309
310 pub(crate) fn emit_loop_merge(&mut self, merge_label: u32, continue_label: u32) {
311 self.emit(
312 OP_LOOP_MERGE,
313 &[merge_label, continue_label, LOOP_CONTROL_NONE],
314 );
315 }
316
317 pub(crate) fn emit_opencl_ext(
318 &mut self,
319 ext_id: u32,
320 result_ty: u32,
321 result: u32,
322 inst: u32,
323 args: &[u32],
324 ) {
325 let mut ops = vec![result_ty, result, ext_id, inst];
326 ops.extend_from_slice(args);
327 self.emit(OP_EXT_INST, &ops);
328 }
329}
330
331impl Default for SpvModule {
332 fn default() -> Self {
333 Self::new()
334 }
335}
336
/// Result ids produced by [`emit_preamble`] and shared by every kernel
/// builder: common scalar/vector types, pointer types, small constants,
/// the GlobalInvocationId input variable, and the OpenCL.std import.
pub(crate) struct BaseIds {
    /// OpTypeVoid.
    pub(crate) ty_void: u32,
    /// OpTypeBool.
    pub(crate) ty_bool: u32,
    /// 32-bit unsigned OpTypeInt.
    pub(crate) ty_uint: u32,
    /// 32-bit OpTypeFloat.
    pub(crate) ty_float: u32,
    /// uvec3 (vector of 3 uint), the type of GlobalInvocationId.
    #[allow(dead_code)]
    pub(crate) ty_v3uint: u32,
    /// Function type `void()`.
    #[allow(dead_code)]
    pub(crate) ty_fn_void: u32,
    /// Pointer to uvec3 in the Input storage class.
    #[allow(dead_code)]
    pub(crate) ty_ptr_input_v3uint: u32,
    /// Pointer to float in CrossWorkgroup (global) storage.
    pub(crate) ty_ptr_cross_float: u32,
    /// Pointer to float in Function (local variable) storage.
    pub(crate) ty_ptr_func_float: u32,
    /// Pointer to uint in Function (local variable) storage.
    pub(crate) ty_ptr_func_uint: u32,
    /// Constant 0u32.
    pub(crate) c_uint_0: u32,
    /// Constant 1u32.
    pub(crate) c_uint_1: u32,
    /// Constant 0.0f32.
    pub(crate) c_float_0: u32,
    /// Constant 1.0f32.
    pub(crate) c_float_1: u32,
    /// Input variable decorated BuiltIn GlobalInvocationId.
    pub(crate) var_gid: u32,
    /// Id of the imported "OpenCL.std" extended instruction set.
    pub(crate) opencl_ext: u32,
}
361
/// Emits the module prologue shared by all kernels — capabilities,
/// OpenCL.std import, memory model, the GlobalInvocationId builtin
/// decoration, common types/constants, and the gid input variable —
/// and returns the allocated ids in a [`BaseIds`].
///
/// NOTE(review): callers emit OpEntryPoint/OpExecutionMode *after* this
/// prologue, i.e. after the annotation and type words already emitted here.
/// The SPIR-V logical module layout places entry points before both —
/// confirm the target consumer tolerates this ordering.
pub(crate) fn emit_preamble(m: &mut SpvModule) -> BaseIds {
    // Reserve all ids up front; the emission order below then only has to
    // respect SPIR-V section ordering, not id numbering.
    let ty_void = m.alloc_id();
    let ty_bool = m.alloc_id();
    let ty_uint = m.alloc_id();
    let ty_float = m.alloc_id();
    let ty_v3uint = m.alloc_id();
    let ty_fn_void = m.alloc_id();
    let ty_ptr_input_v3uint = m.alloc_id();
    let ty_ptr_cross_float = m.alloc_id();
    let ty_ptr_func_float = m.alloc_id();
    let ty_ptr_func_uint = m.alloc_id();
    let c_uint_0 = m.alloc_id();
    let c_uint_1 = m.alloc_id();
    let c_float_0 = m.alloc_id();
    let c_float_1 = m.alloc_id();
    let var_gid = m.alloc_id();
    let opencl_ext = m.alloc_id();

    // Capabilities: OpenCL-style kernel with physical 64-bit addressing.
    m.emit_capability(CAPABILITY_KERNEL);
    m.emit_capability(CAPABILITY_ADDRESSES);

    m.emit_ext_inst_import(opencl_ext, "OpenCL.std");

    m.emit_memory_model(ADDRESSING_MODEL_PHYSICAL64, MEMORY_MODEL_OPENCL);

    // Annotation: mark var_gid as the GlobalInvocationId builtin.
    m.emit_decorate(var_gid, DECORATION_BUILTIN, &[BUILTIN_GLOBAL_INVOCATION_ID]);

    // Types (each referenced type must be declared before its use).
    m.emit_type_void(ty_void);
    m.emit_type_bool(ty_bool);
    m.emit_type_int(ty_uint, 32, 0);
    m.emit_type_float(ty_float, 32);
    m.emit_type_vector(ty_v3uint, ty_uint, 3);
    m.emit_type_function(ty_fn_void, ty_void, &[]);
    m.emit_type_pointer(ty_ptr_input_v3uint, STORAGE_CLASS_INPUT, ty_v3uint);
    m.emit_type_pointer(ty_ptr_cross_float, STORAGE_CLASS_CROSS_WORKGROUP, ty_float);
    m.emit_type_pointer(ty_ptr_func_float, STORAGE_CLASS_FUNCTION, ty_float);
    m.emit_type_pointer(ty_ptr_func_uint, STORAGE_CLASS_FUNCTION, ty_uint);

    // Small constants every kernel reuses.
    m.emit_constant_u32(ty_uint, c_uint_0, 0);
    m.emit_constant_u32(ty_uint, c_uint_1, 1);
    m.emit_constant_f32(ty_float, c_float_0, 0.0);
    m.emit_constant_f32(ty_float, c_float_1, 1.0);

    // Global invocation id input variable (uvec3).
    m.emit_variable(ty_ptr_input_v3uint, var_gid, STORAGE_CLASS_INPUT);

    BaseIds {
        ty_void,
        ty_bool,
        ty_uint,
        ty_float,
        ty_v3uint,
        ty_fn_void,
        ty_ptr_input_v3uint,
        ty_ptr_cross_float,
        ty_ptr_func_float,
        ty_ptr_func_uint,
        c_uint_0,
        c_uint_1,
        c_float_0,
        c_float_1,
        var_gid,
        opencl_ext,
    }
}
441
442pub(crate) fn load_gid_x(m: &mut SpvModule, b: &BaseIds) -> u32 {
444 let gid_val = m.alloc_id();
445 m.emit_load(b.ty_v3uint, gid_val, b.var_gid);
446 let gid_x = m.alloc_id();
447 m.emit(OP_COMPOSITE_EXTRACT, &[b.ty_uint, gid_x, gid_val, 0]);
448 gid_x
449}
450
/// Builds a complete SPIR-V kernel computing `output[i] = op(input[i])`
/// for every `i < count`, one element per invocation.
///
/// Kernel signature: `(input: *f32, output: *f32, count: u32)`.
/// Invocations with gid >= count branch straight to the merge block.
pub fn unary_compute_shader(op: UnaryOp) -> Vec<u32> {
    let mut m = SpvModule::new();
    let b = emit_preamble(&mut m);

    let main_fn = m.alloc_id();
    let fn_ty = m.alloc_id();
    let p_input = m.alloc_id();
    let p_output = m.alloc_id();
    let p_count = m.alloc_id();

    // Function type: void(*f32, *f32, u32).
    m.emit_type_function(
        fn_ty,
        b.ty_void,
        &[b.ty_ptr_cross_float, b.ty_ptr_cross_float, b.ty_uint],
    );

    // NOTE(review): OpEntryPoint lands after the type/annotation words from
    // the preamble — see layout note on emit_preamble.
    m.emit_entry_point(EXECUTION_MODEL_KERNEL, main_fn, "main", &[b.var_gid]);
    m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);

    let label_entry = m.alloc_id();
    let label_body = m.alloc_id();
    let label_merge = m.alloc_id();

    m.emit_function(b.ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
    m.emit_function_parameter(b.ty_ptr_cross_float, p_input);
    m.emit_function_parameter(b.ty_ptr_cross_float, p_output);
    m.emit_function_parameter(b.ty_uint, p_count);
    m.emit_label(label_entry);

    let gid = load_gid_x(&mut m, &b);

    // Bounds check: only invocations with gid < count touch memory.
    let cond = m.alloc_id();
    m.emit(OP_U_LESS_THAN, &[b.ty_bool, cond, gid, p_count]);
    m.emit_selection_merge(label_merge);
    m.emit_branch_conditional(cond, label_body, label_merge);

    m.emit_label(label_body);

    // output[gid] = op(input[gid])
    let inp_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, inp_ptr, p_input, gid);
    let inp_val = m.alloc_id();
    m.emit_load(b.ty_float, inp_val, inp_ptr);

    let result = emit_unary_op(&mut m, &b, op, inp_val);

    let out_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, out_ptr, p_output, gid);
    m.emit_store(out_ptr, result);

    m.emit_branch(label_merge);

    m.emit_label(label_merge);
    m.emit_return();
    m.emit_function_end();

    m.finalize()
}
520
/// Emits the instruction sequence for one unary op applied to float id `x`
/// and returns the id of the result.
///
/// Most ops map to a single OpenCL.std extended instruction; Sigmoid is
/// composed from primitives as 1 / (1 + exp(-x)).
fn emit_unary_op(m: &mut SpvModule, b: &BaseIds, op: UnaryOp, x: u32) -> u32 {
    // The result id is reserved first so every arm writes into the same id.
    let result = m.alloc_id();
    match op {
        // relu(x) = fmax(0, x)
        UnaryOp::Relu => {
            m.emit_opencl_ext(
                b.opencl_ext,
                b.ty_float,
                result,
                OPENCL_FMAX,
                &[b.c_float_0, x],
            );
        }
        // sigmoid(x) = 1 / (1 + exp(-x))
        UnaryOp::Sigmoid => {
            let neg_x = m.alloc_id();
            m.emit(OP_F_NEGATE, &[b.ty_float, neg_x, x]);
            let exp_neg_x = m.alloc_id();
            m.emit_opencl_ext(b.opencl_ext, b.ty_float, exp_neg_x, OPENCL_EXP, &[neg_x]);
            let one_plus = m.alloc_id();
            m.emit(OP_F_ADD, &[b.ty_float, one_plus, b.c_float_1, exp_neg_x]);
            m.emit(OP_F_DIV, &[b.ty_float, result, b.c_float_1, one_plus]);
        }
        UnaryOp::Tanh => {
            m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_TANH, &[x]);
        }
        UnaryOp::Exp => {
            m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_EXP, &[x]);
        }
        UnaryOp::Log => {
            m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_LOG, &[x]);
        }
        UnaryOp::Sqrt => {
            m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_SQRT, &[x]);
        }
        UnaryOp::Abs => {
            m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_FABS, &[x]);
        }
        UnaryOp::Neg => {
            m.emit(OP_F_NEGATE, &[b.ty_float, result, x]);
        }
    }
    result
}
564
/// Builds a complete SPIR-V kernel computing `out[i] = op(a[i], b[i])`
/// for every `i < count`, one element per invocation.
///
/// Kernel signature: `(a: *f32, b: *f32, out: *f32, count: u32)`.
pub fn binary_compute_shader(op: BinaryOp) -> Vec<u32> {
    let mut m = SpvModule::new();
    let b = emit_preamble(&mut m);

    let main_fn = m.alloc_id();
    let fn_ty = m.alloc_id();
    let p_a = m.alloc_id();
    let p_b = m.alloc_id();
    let p_out = m.alloc_id();
    let p_count = m.alloc_id();

    // Function type: void(*f32, *f32, *f32, u32).
    m.emit_type_function(
        fn_ty,
        b.ty_void,
        &[
            b.ty_ptr_cross_float,
            b.ty_ptr_cross_float,
            b.ty_ptr_cross_float,
            b.ty_uint,
        ],
    );

    m.emit_entry_point(EXECUTION_MODEL_KERNEL, main_fn, "main", &[b.var_gid]);
    m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);

    let label_entry = m.alloc_id();
    let label_body = m.alloc_id();
    let label_merge = m.alloc_id();

    m.emit_function(b.ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
    m.emit_function_parameter(b.ty_ptr_cross_float, p_a);
    m.emit_function_parameter(b.ty_ptr_cross_float, p_b);
    m.emit_function_parameter(b.ty_ptr_cross_float, p_out);
    m.emit_function_parameter(b.ty_uint, p_count);
    m.emit_label(label_entry);

    let gid = load_gid_x(&mut m, &b);

    // Bounds check: skip the body when gid >= count.
    let cond = m.alloc_id();
    m.emit(OP_U_LESS_THAN, &[b.ty_bool, cond, gid, p_count]);
    m.emit_selection_merge(label_merge);
    m.emit_branch_conditional(cond, label_body, label_merge);

    m.emit_label(label_body);

    // Load both operands at index gid.
    let a_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, a_ptr, p_a, gid);
    let a_val = m.alloc_id();
    m.emit_load(b.ty_float, a_val, a_ptr);

    let b_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, b_ptr, p_b, gid);
    let b_val = m.alloc_id();
    m.emit_load(b.ty_float, b_val, b_ptr);

    let result = emit_binary_op(&mut m, &b, op, a_val, b_val);

    let out_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, out_ptr, p_out, gid);
    m.emit_store(out_ptr, result);

    m.emit_branch(label_merge);

    m.emit_label(label_merge);
    m.emit_return();
    m.emit_function_end();

    m.finalize()
}
641
642fn emit_binary_op(m: &mut SpvModule, b: &BaseIds, op: BinaryOp, lhs: u32, rhs: u32) -> u32 {
643 let result = m.alloc_id();
644 match op {
645 BinaryOp::Add => m.emit(OP_F_ADD, &[b.ty_float, result, lhs, rhs]),
646 BinaryOp::Sub => m.emit(OP_F_SUB, &[b.ty_float, result, lhs, rhs]),
647 BinaryOp::Mul => m.emit(OP_F_MUL, &[b.ty_float, result, lhs, rhs]),
648 BinaryOp::Div => m.emit(OP_F_DIV, &[b.ty_float, result, lhs, rhs]),
649 BinaryOp::Max => {
650 m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_FMAX, &[lhs, rhs]);
651 }
652 BinaryOp::Min => {
653 m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_FMIN, &[lhs, rhs]);
654 }
655 }
656 result
657}
658
/// Builds a SPIR-V kernel reducing the middle axis of a conceptually
/// `(outer, reduce, inner)`-shaped row-major input into an
/// `(outer, inner)`-shaped output. Each invocation owns one output element
/// (gid < outer * inner) and loops serially over the reduce axis.
///
/// Kernel signature: `(input: *f32, output: *f32, outer: u32, reduce: u32,
/// inner: u32)`. Mean divides the sum by `reduce` at the end.
pub fn reduce_compute_shader(op: ReduceOp) -> Vec<u32> {
    let mut m = SpvModule::new();
    let b = emit_preamble(&mut m);

    let main_fn = m.alloc_id();
    let fn_ty = m.alloc_id();
    let p_input = m.alloc_id();
    let p_output = m.alloc_id();
    let p_outer = m.alloc_id();
    let p_reduce = m.alloc_id();
    let p_inner = m.alloc_id();

    // Function type: void(*f32, *f32, u32, u32, u32).
    m.emit_type_function(
        fn_ty,
        b.ty_void,
        &[
            b.ty_ptr_cross_float,
            b.ty_ptr_cross_float,
            b.ty_uint,
            b.ty_uint,
            b.ty_uint,
        ],
    );

    m.emit_entry_point(EXECUTION_MODEL_KERNEL, main_fn, "main", &[b.var_gid]);
    m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);

    let label_entry = m.alloc_id();
    let label_bounds_body = m.alloc_id();
    let label_bounds_merge = m.alloc_id();
    let label_loop_header = m.alloc_id();
    let label_loop_body = m.alloc_id();
    let label_loop_continue = m.alloc_id();
    let label_loop_merge = m.alloc_id();

    m.emit_function(b.ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
    m.emit_function_parameter(b.ty_ptr_cross_float, p_input);
    m.emit_function_parameter(b.ty_ptr_cross_float, p_output);
    m.emit_function_parameter(b.ty_uint, p_outer);
    m.emit_function_parameter(b.ty_uint, p_reduce);
    m.emit_function_parameter(b.ty_uint, p_inner);
    m.emit_label(label_entry);

    let gid = load_gid_x(&mut m, &b);

    // total_output = outer * inner (number of output elements).
    let total_output = m.alloc_id();
    m.emit(OP_I_MUL, &[b.ty_uint, total_output, p_outer, p_inner]);

    // Bounds check: only gid < total_output does any work.
    let cond_bounds = m.alloc_id();
    m.emit(OP_U_LESS_THAN, &[b.ty_bool, cond_bounds, gid, total_output]);
    m.emit_selection_merge(label_bounds_merge);
    m.emit_branch_conditional(cond_bounds, label_bounds_body, label_bounds_merge);

    m.emit_label(label_bounds_body);

    // Decompose gid into (outer_idx, inner_idx).
    let outer_idx = m.alloc_id();
    m.emit(OP_U_DIV, &[b.ty_uint, outer_idx, gid, p_inner]);
    let inner_idx = m.alloc_id();
    m.emit(OP_U_MOD, &[b.ty_uint, inner_idx, gid, p_inner]);

    // base_idx = outer_idx * reduce * inner + inner_idx — the first input
    // element of this output's reduction strip; successive elements are
    // `inner` apart.
    let t1 = m.alloc_id();
    m.emit(OP_I_MUL, &[b.ty_uint, t1, outer_idx, p_reduce]);
    let t2 = m.alloc_id();
    m.emit(OP_I_MUL, &[b.ty_uint, t2, t1, p_inner]);
    let base_idx = m.alloc_id();
    m.emit(OP_I_ADD, &[b.ty_uint, base_idx, t2, inner_idx]);

    // Loop counter variable i = 0.
    // NOTE(review): these Function-storage OpVariables are emitted in the
    // bounds-check block; the SPIR-V spec requires all such variables at the
    // start of the function's *first* block — confirm the consumer accepts this.
    let var_i = m.alloc_id();
    m.emit_variable(b.ty_ptr_func_uint, var_i, STORAGE_CLASS_FUNCTION);
    m.emit_store(var_i, b.c_uint_0);

    // Accumulator, seeded with the op's identity element.
    let var_acc = m.alloc_id();
    m.emit_variable(b.ty_ptr_func_float, var_acc, STORAGE_CLASS_FUNCTION);
    let init_val = match op {
        ReduceOp::Sum | ReduceOp::Mean => b.c_float_0,
        ReduceOp::Max => {
            // NOTE(review): this OpConstant is emitted inside the function
            // body; SPIR-V places constants in the module-level
            // types/constants section — verify the consumer tolerates it.
            let neg_inf = m.alloc_id();
            m.emit_constant_f32(b.ty_float, neg_inf, f32::NEG_INFINITY);
            neg_inf
        }
        ReduceOp::Min => {
            // NOTE(review): same mid-function OpConstant placement concern.
            let pos_inf = m.alloc_id();
            m.emit_constant_f32(b.ty_float, pos_inf, f32::INFINITY);
            pos_inf
        }
    };
    m.emit_store(var_acc, init_val);

    m.emit_branch(label_loop_header);

    // Loop header: while (i < reduce).
    m.emit_label(label_loop_header);
    let i_val = m.alloc_id();
    m.emit_load(b.ty_uint, i_val, var_i);
    let loop_cond = m.alloc_id();
    m.emit(OP_U_LESS_THAN, &[b.ty_bool, loop_cond, i_val, p_reduce]);
    m.emit_loop_merge(label_loop_merge, label_loop_continue);
    m.emit_branch_conditional(loop_cond, label_loop_body, label_loop_merge);

    m.emit_label(label_loop_body);

    // input_idx = base_idx + i * inner
    let i_times_inner = m.alloc_id();
    m.emit(OP_I_MUL, &[b.ty_uint, i_times_inner, i_val, p_inner]);
    let input_idx = m.alloc_id();
    m.emit(OP_I_ADD, &[b.ty_uint, input_idx, base_idx, i_times_inner]);

    let inp_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, inp_ptr, p_input, input_idx);
    let inp_val = m.alloc_id();
    m.emit_load(b.ty_float, inp_val, inp_ptr);

    let acc_val = m.alloc_id();
    m.emit_load(b.ty_float, acc_val, var_acc);

    // acc = combine(acc, input[input_idx])
    let new_acc = m.alloc_id();
    match op {
        ReduceOp::Sum | ReduceOp::Mean => {
            m.emit(OP_F_ADD, &[b.ty_float, new_acc, acc_val, inp_val]);
        }
        ReduceOp::Max => {
            m.emit_opencl_ext(
                b.opencl_ext,
                b.ty_float,
                new_acc,
                OPENCL_FMAX,
                &[acc_val, inp_val],
            );
        }
        ReduceOp::Min => {
            m.emit_opencl_ext(
                b.opencl_ext,
                b.ty_float,
                new_acc,
                OPENCL_FMIN,
                &[acc_val, inp_val],
            );
        }
    }
    m.emit_store(var_acc, new_acc);

    m.emit_branch(label_loop_continue);

    // Continue block: i += 1, back-edge to the header.
    m.emit_label(label_loop_continue);
    let i_inc = m.alloc_id();
    m.emit(OP_I_ADD, &[b.ty_uint, i_inc, i_val, b.c_uint_1]);
    m.emit_store(var_i, i_inc);
    m.emit_branch(label_loop_header);

    m.emit_label(label_loop_merge);

    let final_acc = m.alloc_id();
    m.emit_load(b.ty_float, final_acc, var_acc);

    // Mean: divide the accumulated sum by the reduce-axis length.
    let store_val = if op == ReduceOp::Mean {
        let reduce_f = m.alloc_id();
        m.emit(OP_CONVERT_U_TO_F, &[b.ty_float, reduce_f, p_reduce]);
        let mean_val = m.alloc_id();
        m.emit(OP_F_DIV, &[b.ty_float, mean_val, final_acc, reduce_f]);
        mean_val
    } else {
        final_acc
    };

    let out_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, out_ptr, p_output, gid);
    m.emit_store(out_ptr, store_val);

    m.emit_branch(label_bounds_merge);

    m.emit_label(label_bounds_merge);
    m.emit_return();
    m.emit_function_end();

    m.finalize()
}
853
/// Builds a SPIR-V kernel computing a naive single-precision GEMM:
/// `C = alpha * A·B + beta * C` with row-major A (m×k), B (k×n), C (m×n).
/// Each invocation owns one C element (gid < m*n) and loops serially over k.
///
/// Kernel signature: `(a: *f32, b: *f32, c: *f32, m: u32, n: u32, k: u32,
/// alpha: f32, beta: f32)`.
pub fn gemm_compute_shader() -> Vec<u32> {
    let mut m = SpvModule::new();
    let b = emit_preamble(&mut m);

    let main_fn = m.alloc_id();
    let fn_ty = m.alloc_id();
    let p_a = m.alloc_id();
    let p_b = m.alloc_id();
    let p_c = m.alloc_id();
    let p_m = m.alloc_id();
    let p_n = m.alloc_id();
    let p_k = m.alloc_id();
    let p_alpha = m.alloc_id();
    let p_beta = m.alloc_id();

    // Function type: void(*f32, *f32, *f32, u32, u32, u32, f32, f32).
    m.emit_type_function(
        fn_ty,
        b.ty_void,
        &[
            b.ty_ptr_cross_float,
            b.ty_ptr_cross_float,
            b.ty_ptr_cross_float,
            b.ty_uint,
            b.ty_uint,
            b.ty_uint,
            b.ty_float,
            b.ty_float,
        ],
    );

    m.emit_entry_point(EXECUTION_MODEL_KERNEL, main_fn, "main", &[b.var_gid]);
    m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);

    let label_entry = m.alloc_id();
    let label_bounds_body = m.alloc_id();
    let label_bounds_merge = m.alloc_id();
    let label_loop_header = m.alloc_id();
    let label_loop_body = m.alloc_id();
    let label_loop_continue = m.alloc_id();
    let label_loop_merge = m.alloc_id();

    m.emit_function(b.ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
    m.emit_function_parameter(b.ty_ptr_cross_float, p_a);
    m.emit_function_parameter(b.ty_ptr_cross_float, p_b);
    m.emit_function_parameter(b.ty_ptr_cross_float, p_c);
    m.emit_function_parameter(b.ty_uint, p_m);
    m.emit_function_parameter(b.ty_uint, p_n);
    m.emit_function_parameter(b.ty_uint, p_k);
    m.emit_function_parameter(b.ty_float, p_alpha);
    m.emit_function_parameter(b.ty_float, p_beta);
    m.emit_label(label_entry);

    let gid = load_gid_x(&mut m, &b);

    // total = m * n output elements.
    let total = m.alloc_id();
    m.emit(OP_I_MUL, &[b.ty_uint, total, p_m, p_n]);

    // Bounds check on gid.
    let cond = m.alloc_id();
    m.emit(OP_U_LESS_THAN, &[b.ty_bool, cond, gid, total]);
    m.emit_selection_merge(label_bounds_merge);
    m.emit_branch_conditional(cond, label_bounds_body, label_bounds_merge);

    m.emit_label(label_bounds_body);

    // row = gid / n, col = gid % n — the C element this invocation owns.
    let row = m.alloc_id();
    m.emit(OP_U_DIV, &[b.ty_uint, row, gid, p_n]);
    let col = m.alloc_id();
    m.emit(OP_U_MOD, &[b.ty_uint, col, gid, p_n]);

    // Loop counter i = 0 and dot-product accumulator = 0.0.
    // NOTE(review): Function-storage OpVariables emitted outside the first
    // block; SPIR-V requires them at the start of the entry block — confirm
    // the consumer tolerates this placement.
    let var_i = m.alloc_id();
    m.emit_variable(b.ty_ptr_func_uint, var_i, STORAGE_CLASS_FUNCTION);
    m.emit_store(var_i, b.c_uint_0);
    let var_acc = m.alloc_id();
    m.emit_variable(b.ty_ptr_func_float, var_acc, STORAGE_CLASS_FUNCTION);
    m.emit_store(var_acc, b.c_float_0);

    m.emit_branch(label_loop_header);

    // Loop header: while (i < k).
    m.emit_label(label_loop_header);
    let i_val = m.alloc_id();
    m.emit_load(b.ty_uint, i_val, var_i);
    let loop_cond = m.alloc_id();
    m.emit(OP_U_LESS_THAN, &[b.ty_bool, loop_cond, i_val, p_k]);
    m.emit_loop_merge(label_loop_merge, label_loop_continue);
    m.emit_branch_conditional(loop_cond, label_loop_body, label_loop_merge);

    m.emit_label(label_loop_body);

    // a_idx = row * k + i  (A is row-major m×k).
    let row_k = m.alloc_id();
    m.emit(OP_I_MUL, &[b.ty_uint, row_k, row, p_k]);
    let a_idx = m.alloc_id();
    m.emit(OP_I_ADD, &[b.ty_uint, a_idx, row_k, i_val]);

    // b_idx = i * n + col  (B is row-major k×n).
    let i_n = m.alloc_id();
    m.emit(OP_I_MUL, &[b.ty_uint, i_n, i_val, p_n]);
    let b_idx = m.alloc_id();
    m.emit(OP_I_ADD, &[b.ty_uint, b_idx, i_n, col]);

    let a_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, a_ptr, p_a, a_idx);
    let a_val = m.alloc_id();
    m.emit_load(b.ty_float, a_val, a_ptr);

    let b_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, b_ptr, p_b, b_idx);
    let b_val = m.alloc_id();
    m.emit_load(b.ty_float, b_val, b_ptr);

    // acc += A[row][i] * B[i][col]
    let prod = m.alloc_id();
    m.emit(OP_F_MUL, &[b.ty_float, prod, a_val, b_val]);
    let old_acc = m.alloc_id();
    m.emit_load(b.ty_float, old_acc, var_acc);
    let new_acc = m.alloc_id();
    m.emit(OP_F_ADD, &[b.ty_float, new_acc, old_acc, prod]);
    m.emit_store(var_acc, new_acc);

    m.emit_branch(label_loop_continue);

    // Continue block: i += 1, back-edge to header.
    m.emit_label(label_loop_continue);
    let i_inc = m.alloc_id();
    m.emit(OP_I_ADD, &[b.ty_uint, i_inc, i_val, b.c_uint_1]);
    m.emit_store(var_i, i_inc);
    m.emit_branch(label_loop_header);

    m.emit_label(label_loop_merge);

    // C[row][col] = alpha * acc + beta * C[row][col]
    let final_acc = m.alloc_id();
    m.emit_load(b.ty_float, final_acc, var_acc);
    let alpha_acc = m.alloc_id();
    m.emit(OP_F_MUL, &[b.ty_float, alpha_acc, p_alpha, final_acc]);

    let c_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, c_ptr, p_c, gid);
    let c_old = m.alloc_id();
    m.emit_load(b.ty_float, c_old, c_ptr);
    let beta_c = m.alloc_id();
    m.emit(OP_F_MUL, &[b.ty_float, beta_c, p_beta, c_old]);
    let c_new = m.alloc_id();
    m.emit(OP_F_ADD, &[b.ty_float, c_new, alpha_acc, beta_c]);
    m.emit_store(c_ptr, c_new);

    m.emit_branch(label_bounds_merge);

    m.emit_label(label_bounds_merge);
    m.emit_return();
    m.emit_function_end();

    m.finalize()
}
1024
1025fn load_gid_z(m: &mut SpvModule, b: &BaseIds) -> u32 {
1029 let gid_val = m.alloc_id();
1030 m.emit_load(b.ty_v3uint, gid_val, b.var_gid);
1031 let gid_z = m.alloc_id();
1032 m.emit(OP_COMPOSITE_EXTRACT, &[b.ty_uint, gid_z, gid_val, 2]);
1033 gid_z
1034}
1035
1036pub fn batched_gemm_compute_shader() -> Vec<u32> {
1051 let mut m = SpvModule::new();
1052 let b = emit_preamble(&mut m);
1053
1054 let main_fn = m.alloc_id();
1055 let fn_ty = m.alloc_id();
1056 let p_a = m.alloc_id();
1057 let p_b = m.alloc_id();
1058 let p_c = m.alloc_id();
1059 let p_m = m.alloc_id();
1060 let p_n = m.alloc_id();
1061 let p_k = m.alloc_id();
1062 let p_alpha = m.alloc_id();
1063 let p_beta = m.alloc_id();
1064 let p_batch_count = m.alloc_id();
1065 let p_stride_a = m.alloc_id();
1066 let p_stride_b = m.alloc_id();
1067 let p_stride_c = m.alloc_id();
1068
1069 m.emit_type_function(
1072 fn_ty,
1073 b.ty_void,
1074 &[
1075 b.ty_ptr_cross_float,
1076 b.ty_ptr_cross_float,
1077 b.ty_ptr_cross_float,
1078 b.ty_uint,
1079 b.ty_uint,
1080 b.ty_uint,
1081 b.ty_float,
1082 b.ty_float,
1083 b.ty_uint,
1084 b.ty_uint,
1085 b.ty_uint,
1086 b.ty_uint,
1087 ],
1088 );
1089
1090 m.emit_entry_point(EXECUTION_MODEL_KERNEL, main_fn, "main", &[b.var_gid]);
1091 m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);
1092
1093 let label_entry = m.alloc_id();
1094 let label_bounds_body = m.alloc_id();
1095 let label_bounds_merge = m.alloc_id();
1096 let label_loop_header = m.alloc_id();
1097 let label_loop_body = m.alloc_id();
1098 let label_loop_continue = m.alloc_id();
1099 let label_loop_merge = m.alloc_id();
1100
1101 m.emit_function(b.ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
1102 m.emit_function_parameter(b.ty_ptr_cross_float, p_a);
1103 m.emit_function_parameter(b.ty_ptr_cross_float, p_b);
1104 m.emit_function_parameter(b.ty_ptr_cross_float, p_c);
1105 m.emit_function_parameter(b.ty_uint, p_m);
1106 m.emit_function_parameter(b.ty_uint, p_n);
1107 m.emit_function_parameter(b.ty_uint, p_k);
1108 m.emit_function_parameter(b.ty_float, p_alpha);
1109 m.emit_function_parameter(b.ty_float, p_beta);
1110 m.emit_function_parameter(b.ty_uint, p_batch_count);
1111 m.emit_function_parameter(b.ty_uint, p_stride_a);
1112 m.emit_function_parameter(b.ty_uint, p_stride_b);
1113 m.emit_function_parameter(b.ty_uint, p_stride_c);
1114 m.emit_label(label_entry);
1115
1116 let gid = load_gid_x(&mut m, &b);
1118 let batch_idx = load_gid_z(&mut m, &b);
1120
1121 let total = m.alloc_id();
1123 m.emit(OP_I_MUL, &[b.ty_uint, total, p_m, p_n]);
1124
1125 let cond1 = m.alloc_id();
1127 m.emit(OP_U_LESS_THAN, &[b.ty_bool, cond1, gid, total]);
1128 let cond2 = m.alloc_id();
1129 m.emit(
1130 OP_U_LESS_THAN,
1131 &[b.ty_bool, cond2, batch_idx, p_batch_count],
1132 );
1133 let cond = m.alloc_id();
1135 m.emit(166, &[b.ty_bool, cond, cond1, cond2]);
1137 m.emit_selection_merge(label_bounds_merge);
1138 m.emit_branch_conditional(cond, label_bounds_body, label_bounds_merge);
1139
1140 m.emit_label(label_bounds_body);
1141
1142 let a_offset = m.alloc_id();
1144 m.emit(OP_I_MUL, &[b.ty_uint, a_offset, batch_idx, p_stride_a]);
1145 let b_offset = m.alloc_id();
1146 m.emit(OP_I_MUL, &[b.ty_uint, b_offset, batch_idx, p_stride_b]);
1147 let c_offset = m.alloc_id();
1148 m.emit(OP_I_MUL, &[b.ty_uint, c_offset, batch_idx, p_stride_c]);
1149
1150 let a_batch = m.alloc_id();
1152 m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, a_batch, p_a, a_offset);
1153 let b_batch = m.alloc_id();
1154 m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, b_batch, p_b, b_offset);
1155 let c_batch = m.alloc_id();
1156 m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, c_batch, p_c, c_offset);
1157
1158 let row = m.alloc_id();
1160 m.emit(OP_U_DIV, &[b.ty_uint, row, gid, p_n]);
1161 let col = m.alloc_id();
1162 m.emit(OP_U_MOD, &[b.ty_uint, col, gid, p_n]);
1163
1164 let var_i = m.alloc_id();
1166 m.emit_variable(b.ty_ptr_func_uint, var_i, STORAGE_CLASS_FUNCTION);
1167 m.emit_store(var_i, b.c_uint_0);
1168 let var_acc = m.alloc_id();
1169 m.emit_variable(b.ty_ptr_func_float, var_acc, STORAGE_CLASS_FUNCTION);
1170 m.emit_store(var_acc, b.c_float_0);
1171
1172 m.emit_branch(label_loop_header);
1173
1174 m.emit_label(label_loop_header);
1176 let i_val = m.alloc_id();
1177 m.emit_load(b.ty_uint, i_val, var_i);
1178 let loop_cond = m.alloc_id();
1179 m.emit(OP_U_LESS_THAN, &[b.ty_bool, loop_cond, i_val, p_k]);
1180 m.emit_loop_merge(label_loop_merge, label_loop_continue);
1181 m.emit_branch_conditional(loop_cond, label_loop_body, label_loop_merge);
1182
1183 m.emit_label(label_loop_body);
1185
1186 let row_k = m.alloc_id();
1188 m.emit(OP_I_MUL, &[b.ty_uint, row_k, row, p_k]);
1189 let a_idx = m.alloc_id();
1190 m.emit(OP_I_ADD, &[b.ty_uint, a_idx, row_k, i_val]);
1191
1192 let i_n = m.alloc_id();
1194 m.emit(OP_I_MUL, &[b.ty_uint, i_n, i_val, p_n]);
1195 let b_idx = m.alloc_id();
1196 m.emit(OP_I_ADD, &[b.ty_uint, b_idx, i_n, col]);
1197
1198 let a_ptr = m.alloc_id();
1199 m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, a_ptr, a_batch, a_idx);
1200 let a_val = m.alloc_id();
1201 m.emit_load(b.ty_float, a_val, a_ptr);
1202
1203 let b_ptr = m.alloc_id();
1204 m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, b_ptr, b_batch, b_idx);
1205 let b_val = m.alloc_id();
1206 m.emit_load(b.ty_float, b_val, b_ptr);
1207
1208 let prod = m.alloc_id();
1209 m.emit(OP_F_MUL, &[b.ty_float, prod, a_val, b_val]);
1210 let old_acc = m.alloc_id();
1211 m.emit_load(b.ty_float, old_acc, var_acc);
1212 let new_acc = m.alloc_id();
1213 m.emit(OP_F_ADD, &[b.ty_float, new_acc, old_acc, prod]);
1214 m.emit_store(var_acc, new_acc);
1215
1216 m.emit_branch(label_loop_continue);
1217
1218 m.emit_label(label_loop_continue);
1220 let i_inc = m.alloc_id();
1221 m.emit(OP_I_ADD, &[b.ty_uint, i_inc, i_val, b.c_uint_1]);
1222 m.emit_store(var_i, i_inc);
1223 m.emit_branch(label_loop_header);
1224
1225 m.emit_label(label_loop_merge);
1227
1228 let final_acc = m.alloc_id();
1230 m.emit_load(b.ty_float, final_acc, var_acc);
1231 let alpha_acc = m.alloc_id();
1232 m.emit(OP_F_MUL, &[b.ty_float, alpha_acc, p_alpha, final_acc]);
1233
1234 let c_ptr = m.alloc_id();
1235 m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, c_ptr, c_batch, gid);
1236 let c_old = m.alloc_id();
1237 m.emit_load(b.ty_float, c_old, c_ptr);
1238 let beta_c = m.alloc_id();
1239 m.emit(OP_F_MUL, &[b.ty_float, beta_c, p_beta, c_old]);
1240 let c_new = m.alloc_id();
1241 m.emit(OP_F_ADD, &[b.ty_float, c_new, alpha_acc, beta_c]);
1242 m.emit_store(c_ptr, c_new);
1243
1244 m.emit_branch(label_bounds_merge);
1245
1246 m.emit_label(label_bounds_merge);
1247 m.emit_return();
1248 m.emit_function_end();
1249
1250 m.finalize()
1251}
1252
1253pub fn trivial_compute_shader() -> Vec<u32> {
1259 let mut m = SpvModule::new();
1260
1261 let id_main_fn = m.alloc_id();
1262 let id_void = m.alloc_id();
1263 let id_void_fn = m.alloc_id();
1264 let id_label = m.alloc_id();
1265
1266 m.emit_capability(CAPABILITY_SHADER);
1267 m.emit_memory_model(ADDRESSING_MODEL_LOGICAL, MEMORY_MODEL_GLSL450);
1268
1269 let mut entry_words = vec![EXECUTION_MODEL_GLCOMPUTE, id_main_fn];
1270 entry_words.extend(SpvModule::string_words("main"));
1271 m.emit(OP_ENTRY_POINT, &entry_words);
1272
1273 m.emit_execution_mode_local_size(id_main_fn, 1, 1, 1);
1274
1275 m.emit_type_void(id_void);
1276 m.emit_type_function(id_void_fn, id_void, &[]);
1277
1278 m.emit_function(id_void, id_main_fn, FUNCTION_CONTROL_NONE, id_void_fn);
1279 m.emit_label(id_label);
1280 m.emit_return();
1281 m.emit_function_end();
1282
1283 m.finalize()
1284}
1285
1286pub fn trivial_compute_shader_bytes() -> Vec<u8> {
1289 trivial_compute_shader()
1290 .iter()
1291 .flat_map(|w| w.to_ne_bytes())
1292 .collect()
1293}
1294
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared sanity check on a SPIR-V word stream: header length, magic
    /// number, non-zero id bound, and the reserved schema word (always 0).
    fn check_valid_spirv(words: &[u32]) {
        assert!(words.len() >= 5, "too short for SPIR-V header");
        assert_eq!(words[0], SPIRV_MAGIC, "bad magic");
        assert!(words[3] > 0, "ID bound must be > 0");
        assert_eq!(words[4], 0, "schema must be 0");
    }

    #[test]
    fn placeholder_spv_valid_magic() {
        let words = trivial_compute_shader();
        check_valid_spirv(&words);
    }

    #[test]
    fn placeholder_spv_word_aligned() {
        let bytes = trivial_compute_shader_bytes();
        assert_eq!(bytes.len() % 4, 0);
    }

    #[test]
    fn placeholder_spv_version_and_schema() {
        let words = trivial_compute_shader();
        // Version word: any SPIR-V 1.x encodes as 0x0001_0n00.
        assert!(words[1] >= 0x0001_0000);
        assert_eq!(words[4], 0);
    }

    #[test]
    fn placeholder_spv_nonzero_bound() {
        let words = trivial_compute_shader();
        assert!(words[3] > 0);
    }

    #[test]
    fn spv_module_id_allocation_is_monotonic() {
        let mut m = SpvModule::new();
        let id1 = m.alloc_id();
        let id2 = m.alloc_id();
        assert!(id2 > id1);
    }

    #[test]
    fn string_words_null_terminated() {
        // SPIR-V literal strings are UTF-8 packed little-endian into words,
        // with at least one terminating NUL byte.
        let words = SpvModule::string_words("abc");
        assert!(!words.is_empty());
        let bytes: Vec<u8> = words.iter().flat_map(|w| w.to_le_bytes()).collect();
        assert_eq!(bytes[0], b'a');
        assert_eq!(bytes[1], b'b');
        assert_eq!(bytes[2], b'c');
        assert_eq!(bytes[3], 0);
    }

    #[test]
    fn string_words_empty_string() {
        // Even "" must occupy one word holding the NUL terminator.
        let words = SpvModule::string_words("");
        assert!(!words.is_empty());
        let bytes: Vec<u8> = words.iter().flat_map(|w| w.to_le_bytes()).collect();
        assert_eq!(bytes[0], 0);
    }

    #[test]
    fn generator_magic_is_level_zero() {
        assert_eq!(SPIRV_GENERATOR, 0x000D_0002);
        assert_ne!(SPIRV_GENERATOR, 0x000D_0001);
    }

    #[test]
    fn unary_shader_all_ops() {
        let ops = [
            UnaryOp::Relu,
            UnaryOp::Sigmoid,
            UnaryOp::Tanh,
            UnaryOp::Exp,
            UnaryOp::Log,
            UnaryOp::Sqrt,
            UnaryOp::Abs,
            UnaryOp::Neg,
        ];
        for op in ops {
            let words = unary_compute_shader(op);
            check_valid_spirv(&words);
        }
    }

    #[test]
    fn binary_shader_all_ops() {
        let ops = [
            BinaryOp::Add,
            BinaryOp::Sub,
            BinaryOp::Mul,
            BinaryOp::Div,
            BinaryOp::Max,
            BinaryOp::Min,
        ];
        for op in ops {
            let words = binary_compute_shader(op);
            check_valid_spirv(&words);
        }
    }

    #[test]
    fn reduce_shader_all_ops() {
        let ops = [ReduceOp::Sum, ReduceOp::Max, ReduceOp::Min, ReduceOp::Mean];
        for op in ops {
            let words = reduce_compute_shader(op);
            check_valid_spirv(&words);
        }
    }

    #[test]
    fn gemm_shader_valid() {
        let words = gemm_compute_shader();
        check_valid_spirv(&words);
    }

    #[test]
    fn batched_gemm_shader_valid() {
        let words = batched_gemm_compute_shader();
        check_valid_spirv(&words);
    }

    #[test]
    fn batched_gemm_shader_word_aligned() {
        let words = batched_gemm_compute_shader();
        let bytes: Vec<u8> = words.iter().flat_map(|w| w.to_ne_bytes()).collect();
        assert_eq!(bytes.len() % 4, 0);
    }

    #[test]
    fn batched_gemm_shader_uses_kernel_capability() {
        let words = batched_gemm_compute_shader();
        // First instruction after the 5-word header: OpCapability with a
        // word count of 2 (header word + one operand).
        let cap_header = (2u32 << 16) | OP_CAPABILITY;
        assert_eq!(words[5], cap_header);
        // Use the named constant rather than the raw literal `6` so this
        // test stays in sync with the rest of the file.
        assert_eq!(words[6], CAPABILITY_KERNEL);
    }

    #[test]
    fn all_kernel_shaders_word_aligned() {
        fn to_bytes(words: &[u32]) -> Vec<u8> {
            words.iter().flat_map(|w| w.to_ne_bytes()).collect()
        }
        assert_eq!(to_bytes(&unary_compute_shader(UnaryOp::Relu)).len() % 4, 0);
        assert_eq!(to_bytes(&binary_compute_shader(BinaryOp::Add)).len() % 4, 0);
        assert_eq!(to_bytes(&reduce_compute_shader(ReduceOp::Sum)).len() % 4, 0);
        assert_eq!(to_bytes(&gemm_compute_shader()).len() % 4, 0);
        assert_eq!(to_bytes(&batched_gemm_compute_shader()).len() % 4, 0);
    }

    // NOTE(review): despite the name, this test asserts the leading
    // OpCapability words (Shader for the trivial module, Kernel for the
    // compute kernels); it does not inspect the OpMemoryModel instruction.
    #[test]
    fn kernel_shaders_use_opencl_memory_model() {
        let trivial = trivial_compute_shader();
        let unary = unary_compute_shader(UnaryOp::Relu);

        let cap_header = (2u32 << 16) | OP_CAPABILITY;
        assert_eq!(trivial[5], cap_header);
        assert_eq!(trivial[6], CAPABILITY_SHADER);
        assert_eq!(unary[5], cap_header);
        assert_eq!(unary[6], CAPABILITY_KERNEL);
    }
}