1use oxicuda_backend::{BinaryOp, ReduceOp, UnaryOp};
17
/// First word of every SPIR-V module; consumers also use it to detect endianness.
pub const SPIRV_MAGIC: u32 = 0x0723_0203;

/// Version word for SPIR-V 1.2: major version in bits 16..24, minor in bits 8..16.
pub const SPIRV_VERSION_1_2: u32 = (1 << 16) | (2 << 8);

/// Generator magic number written into header word 2 (informational only).
pub const SPIRV_GENERATOR: u32 = 0x000D_0002;
26
// SPIR-V opcodes used by the builders in this file. The first word of every
// instruction packs `(word_count << 16) | opcode`.

// --- Module-level instructions -------------------------------------------

/// `OpExtInstImport`: import an extended instruction set by name.
pub(crate) const OP_EXT_INST_IMPORT: u32 = 11;
/// `OpExtInst`: call an instruction from an imported extended set.
pub(crate) const OP_EXT_INST: u32 = 12;
/// `OpMemoryModel`: addressing model + memory model.
pub(crate) const OP_MEMORY_MODEL: u32 = 14;
/// `OpEntryPoint`: declare an entry point and its interface.
pub(crate) const OP_ENTRY_POINT: u32 = 15;
/// `OpExecutionMode`: attach an execution mode to an entry point.
pub(crate) const OP_EXECUTION_MODE: u32 = 16;
/// `OpCapability`: declare a capability the module uses.
pub(crate) const OP_CAPABILITY: u32 = 17;

// --- Types and constants --------------------------------------------------

/// `OpTypeVoid`.
pub(crate) const OP_TYPE_VOID: u32 = 19;
/// `OpTypeBool`.
pub(crate) const OP_TYPE_BOOL: u32 = 20;
/// `OpTypeInt` (width, signedness).
pub(crate) const OP_TYPE_INT: u32 = 21;
/// `OpTypeFloat` (width).
pub(crate) const OP_TYPE_FLOAT: u32 = 22;
/// `OpTypeVector` (component type, count).
pub(crate) const OP_TYPE_VECTOR: u32 = 23;
/// `OpTypePointer` (storage class, pointee type).
pub(crate) const OP_TYPE_POINTER: u32 = 32;
/// `OpTypeFunction` (return type, parameter types...).
pub(crate) const OP_TYPE_FUNCTION: u32 = 33;
/// `OpConstant` (scalar literal constant).
pub(crate) const OP_CONSTANT: u32 = 43;

// --- Functions and memory -------------------------------------------------

/// `OpFunction`.
pub(crate) const OP_FUNCTION: u32 = 54;
/// `OpFunctionParameter`.
pub(crate) const OP_FUNCTION_PARAMETER: u32 = 55;
/// `OpFunctionEnd`.
pub(crate) const OP_FUNCTION_END: u32 = 56;
/// `OpVariable` (storage class, optional initializer).
pub(crate) const OP_VARIABLE: u32 = 59;
/// `OpLoad`.
pub(crate) const OP_LOAD: u32 = 61;
/// `OpStore`.
pub(crate) const OP_STORE: u32 = 62;
/// `OpInBoundsPtrAccessChain`: pointer arithmetic on a base pointer (element index).
pub(crate) const OP_IN_BOUNDS_PTR_ACCESS_CHAIN: u32 = 70;
/// `OpDecorate`.
pub(crate) const OP_DECORATE: u32 = 71;
/// `OpCompositeExtract`: read one component of a composite value.
pub(crate) const OP_COMPOSITE_EXTRACT: u32 = 81;

// --- Arithmetic and comparison --------------------------------------------

/// `OpConvertUToF`: unsigned integer -> float.
pub(crate) const OP_CONVERT_U_TO_F: u32 = 112;
/// `OpFNegate`.
pub(crate) const OP_F_NEGATE: u32 = 127;
/// `OpIAdd`.
pub(crate) const OP_I_ADD: u32 = 128;
/// `OpFAdd`.
pub(crate) const OP_F_ADD: u32 = 129;
/// `OpFSub`.
pub(crate) const OP_F_SUB: u32 = 131;
/// `OpIMul`.
pub(crate) const OP_I_MUL: u32 = 132;
/// `OpFMul`.
pub(crate) const OP_F_MUL: u32 = 133;
/// `OpUDiv`.
pub(crate) const OP_U_DIV: u32 = 134;
/// `OpFDiv`.
pub(crate) const OP_F_DIV: u32 = 136;
/// `OpUMod`.
pub(crate) const OP_U_MOD: u32 = 137;
/// `OpULessThan`.
pub(crate) const OP_U_LESS_THAN: u32 = 176;

// --- Control flow ---------------------------------------------------------

/// `OpLoopMerge` (merge block, continue target, loop control).
pub(crate) const OP_LOOP_MERGE: u32 = 246;
/// `OpSelectionMerge` (merge block, selection control).
pub(crate) const OP_SELECTION_MERGE: u32 = 247;
/// `OpLabel`: starts a basic block.
pub(crate) const OP_LABEL: u32 = 248;
/// `OpBranch`.
pub(crate) const OP_BRANCH: u32 = 249;
/// `OpBranchConditional`.
pub(crate) const OP_BRANCH_CONDITIONAL: u32 = 250;
/// `OpReturn`.
pub(crate) const OP_RETURN: u32 = 253;
69
// --- Capabilities -----------------------------------------------------------

/// `Shader` capability (used by the trivial GLCompute module).
const CAPABILITY_SHADER: u32 = 1;
/// `Addresses` capability, required by the physical addressing models.
const CAPABILITY_ADDRESSES: u32 = 4;
/// `Kernel` capability (OpenCL-style kernels).
const CAPABILITY_KERNEL: u32 = 6;

// --- Addressing and memory models ---------------------------------------------

/// `Logical` addressing model.
const ADDRESSING_MODEL_LOGICAL: u32 = 0;
/// `Physical64` addressing model (64-bit pointers).
const ADDRESSING_MODEL_PHYSICAL64: u32 = 2;
/// `GLSL450` memory model.
const MEMORY_MODEL_GLSL450: u32 = 1;
/// `OpenCL` memory model.
const MEMORY_MODEL_OPENCL: u32 = 2;

// --- Execution models and modes -----------------------------------------------

/// `GLCompute` execution model.
const EXECUTION_MODEL_GLCOMPUTE: u32 = 5;
/// `Kernel` execution model.
pub(crate) const EXECUTION_MODEL_KERNEL: u32 = 6;
/// `LocalSize` execution mode (x, y, z literals follow).
const EXECUTION_MODE_LOCAL_SIZE: u32 = 17;

/// `FunctionControl` mask with no hints set.
pub(crate) const FUNCTION_CONTROL_NONE: u32 = 0;

// --- Decorations and built-ins --------------------------------------------------

/// `BuiltIn` decoration.
const DECORATION_BUILTIN: u32 = 11;
/// `GlobalInvocationId` built-in.
const BUILTIN_GLOBAL_INVOCATION_ID: u32 = 28;

// --- Storage classes --------------------------------------------------------------

/// `Input` storage class.
const STORAGE_CLASS_INPUT: u32 = 1;
/// `CrossWorkgroup` storage class (OpenCL global memory).
const STORAGE_CLASS_CROSS_WORKGROUP: u32 = 5;
/// `Function` storage class (function-local variables).
pub(crate) const STORAGE_CLASS_FUNCTION: u32 = 7;

// --- Structured control-flow masks -----------------------------------------------

/// `SelectionControl` mask with no hints set.
const SELECTION_CONTROL_NONE: u32 = 0;
/// `LoopControl` mask with no hints set.
const LOOP_CONTROL_NONE: u32 = 0;

// --- OpenCL.std extended instruction numbers --------------------------------------

/// `exp`.
pub(crate) const OPENCL_EXP: u32 = 19;
/// `fabs`.
const OPENCL_FABS: u32 = 23;
/// `fmax`.
pub(crate) const OPENCL_FMAX: u32 = 27;
/// `fmin`.
const OPENCL_FMIN: u32 = 28;
/// `log` (natural logarithm).
const OPENCL_LOG: u32 = 37;
/// `sqrt`.
const OPENCL_SQRT: u32 = 61;
/// `tanh`.
const OPENCL_TANH: u32 = 63;

/// Local size along X used by every kernel builder in this file.
pub(crate) const WORKGROUP_SIZE: u32 = 256;
115
116pub struct SpvModule {
123 words: Vec<u32>,
124 id_bound: u32,
126}
127
128impl SpvModule {
129 pub fn new() -> Self {
131 let words = vec![SPIRV_MAGIC, SPIRV_VERSION_1_2, SPIRV_GENERATOR, 0, 0];
132 Self { words, id_bound: 1 }
133 }
134
135 pub fn alloc_id(&mut self) -> u32 {
137 let id = self.id_bound;
138 self.id_bound += 1;
139 id
140 }
141
142 pub fn emit(&mut self, opcode: u32, operands: &[u32]) {
144 let word_count = (1 + operands.len()) as u32;
145 self.words.push((word_count << 16) | opcode);
146 self.words.extend_from_slice(operands);
147 }
148
149 pub fn string_words(s: &str) -> Vec<u32> {
151 let bytes = s.as_bytes();
152 let padded_len = (bytes.len() + 4) & !3;
153 let mut out = vec![0u32; padded_len / 4];
154 for (i, &b) in bytes.iter().enumerate() {
155 out[i / 4] |= (b as u32) << ((i % 4) * 8);
156 }
157 out
158 }
159
160 pub fn finalize(mut self) -> Vec<u32> {
162 self.words[3] = self.id_bound;
163 self.words
164 }
165
166 pub(crate) fn emit_capability(&mut self, cap: u32) {
169 self.emit(OP_CAPABILITY, &[cap]);
170 }
171
172 pub(crate) fn emit_ext_inst_import(&mut self, id: u32, name: &str) {
173 let mut ops = vec![id];
174 ops.extend(Self::string_words(name));
175 self.emit(OP_EXT_INST_IMPORT, &ops);
176 }
177
178 pub(crate) fn emit_memory_model(&mut self, addressing: u32, memory: u32) {
179 self.emit(OP_MEMORY_MODEL, &[addressing, memory]);
180 }
181
182 pub(crate) fn emit_entry_point(
183 &mut self,
184 model: u32,
185 func_id: u32,
186 name: &str,
187 interfaces: &[u32],
188 ) {
189 let mut ops = vec![model, func_id];
190 ops.extend(Self::string_words(name));
191 ops.extend_from_slice(interfaces);
192 self.emit(OP_ENTRY_POINT, &ops);
193 }
194
195 pub(crate) fn emit_execution_mode_local_size(&mut self, func_id: u32, x: u32, y: u32, z: u32) {
196 self.emit(
197 OP_EXECUTION_MODE,
198 &[func_id, EXECUTION_MODE_LOCAL_SIZE, x, y, z],
199 );
200 }
201
202 pub(crate) fn emit_decorate(&mut self, target: u32, decoration: u32, operands: &[u32]) {
203 let mut ops = vec![target, decoration];
204 ops.extend_from_slice(operands);
205 self.emit(OP_DECORATE, &ops);
206 }
207
208 pub(crate) fn emit_type_void(&mut self, id: u32) {
209 self.emit(OP_TYPE_VOID, &[id]);
210 }
211
212 pub(crate) fn emit_type_bool(&mut self, id: u32) {
213 self.emit(OP_TYPE_BOOL, &[id]);
214 }
215
216 pub(crate) fn emit_type_int(&mut self, id: u32, width: u32, signedness: u32) {
217 self.emit(OP_TYPE_INT, &[id, width, signedness]);
218 }
219
220 pub(crate) fn emit_type_float(&mut self, id: u32, width: u32) {
221 self.emit(OP_TYPE_FLOAT, &[id, width]);
222 }
223
224 pub(crate) fn emit_type_vector(&mut self, id: u32, component: u32, count: u32) {
225 self.emit(OP_TYPE_VECTOR, &[id, component, count]);
226 }
227
228 pub(crate) fn emit_type_pointer(&mut self, id: u32, storage_class: u32, pointee: u32) {
229 self.emit(OP_TYPE_POINTER, &[id, storage_class, pointee]);
230 }
231
232 pub(crate) fn emit_type_function(&mut self, id: u32, return_type: u32, params: &[u32]) {
233 let mut ops = vec![id, return_type];
234 ops.extend_from_slice(params);
235 self.emit(OP_TYPE_FUNCTION, &ops);
236 }
237
238 pub(crate) fn emit_constant_u32(&mut self, ty: u32, id: u32, value: u32) {
239 self.emit(OP_CONSTANT, &[ty, id, value]);
240 }
241
242 pub(crate) fn emit_constant_f32(&mut self, ty: u32, id: u32, value: f32) {
243 self.emit(OP_CONSTANT, &[ty, id, value.to_bits()]);
244 }
245
246 pub(crate) fn emit_variable(&mut self, ty: u32, id: u32, storage_class: u32) {
247 self.emit(OP_VARIABLE, &[ty, id, storage_class]);
248 }
249
250 pub(crate) fn emit_load(&mut self, result_ty: u32, result: u32, pointer: u32) {
251 self.emit(OP_LOAD, &[result_ty, result, pointer]);
252 }
253
254 pub(crate) fn emit_store(&mut self, pointer: u32, value: u32) {
255 self.emit(OP_STORE, &[pointer, value]);
256 }
257
258 pub(crate) fn emit_in_bounds_ptr_access_chain(
259 &mut self,
260 result_ty: u32,
261 result: u32,
262 base: u32,
263 element: u32,
264 ) {
265 self.emit(
266 OP_IN_BOUNDS_PTR_ACCESS_CHAIN,
267 &[result_ty, result, base, element],
268 );
269 }
270
271 pub(crate) fn emit_function(&mut self, result_ty: u32, result: u32, control: u32, fn_ty: u32) {
272 self.emit(OP_FUNCTION, &[result_ty, result, control, fn_ty]);
273 }
274
275 pub(crate) fn emit_function_parameter(&mut self, result_ty: u32, result: u32) {
276 self.emit(OP_FUNCTION_PARAMETER, &[result_ty, result]);
277 }
278
279 pub(crate) fn emit_label(&mut self, id: u32) {
280 self.emit(OP_LABEL, &[id]);
281 }
282
283 pub(crate) fn emit_return(&mut self) {
284 self.emit(OP_RETURN, &[]);
285 }
286
287 pub(crate) fn emit_function_end(&mut self) {
288 self.emit(OP_FUNCTION_END, &[]);
289 }
290
291 pub(crate) fn emit_branch(&mut self, target: u32) {
292 self.emit(OP_BRANCH, &[target]);
293 }
294
295 pub(crate) fn emit_branch_conditional(&mut self, cond: u32, true_label: u32, false_label: u32) {
296 self.emit(OP_BRANCH_CONDITIONAL, &[cond, true_label, false_label]);
297 }
298
299 pub(crate) fn emit_selection_merge(&mut self, merge_label: u32) {
300 self.emit(OP_SELECTION_MERGE, &[merge_label, SELECTION_CONTROL_NONE]);
301 }
302
303 pub(crate) fn emit_loop_merge(&mut self, merge_label: u32, continue_label: u32) {
304 self.emit(
305 OP_LOOP_MERGE,
306 &[merge_label, continue_label, LOOP_CONTROL_NONE],
307 );
308 }
309
310 pub(crate) fn emit_opencl_ext(
311 &mut self,
312 ext_id: u32,
313 result_ty: u32,
314 result: u32,
315 inst: u32,
316 args: &[u32],
317 ) {
318 let mut ops = vec![result_ty, result, ext_id, inst];
319 ops.extend_from_slice(args);
320 self.emit(OP_EXT_INST, &ops);
321 }
322}
323
324impl Default for SpvModule {
325 fn default() -> Self {
326 Self::new()
327 }
328}
329
/// Result ids of the shared declarations emitted by `emit_preamble`.
///
/// Every kernel builder in this file starts from these ids.
#[derive(Debug, Clone, Copy)]
pub(crate) struct BaseIds {
    /// `void`.
    pub(crate) ty_void: u32,
    /// `bool` (result type of comparisons).
    pub(crate) ty_bool: u32,
    /// 32-bit unsigned integer.
    pub(crate) ty_uint: u32,
    /// 32-bit float.
    pub(crate) ty_float: u32,
    /// 3-component vector of `ty_uint` (type of the global-invocation-id variable).
    pub(crate) ty_v3uint: u32,
    /// `void()` function type.
    // Not read by any builder in this file; kept as part of the shared preamble.
    #[allow(dead_code)]
    pub(crate) ty_fn_void: u32,
    /// Input-storage pointer to `ty_v3uint` (type of `var_gid`).
    // Only consumed inside the preamble itself.
    #[allow(dead_code)]
    pub(crate) ty_ptr_input_v3uint: u32,
    /// CrossWorkgroup pointer to `ty_float` (kernel buffer arguments).
    pub(crate) ty_ptr_cross_float: u32,
    /// Function-storage pointer to `ty_float` (float locals).
    pub(crate) ty_ptr_func_float: u32,
    /// Function-storage pointer to `ty_uint` (integer locals).
    pub(crate) ty_ptr_func_uint: u32,
    /// Constant `0u32`.
    pub(crate) c_uint_0: u32,
    /// Constant `1u32`.
    pub(crate) c_uint_1: u32,
    /// Constant `0.0f32`.
    pub(crate) c_float_0: u32,
    /// Constant `1.0f32`.
    pub(crate) c_float_1: u32,
    /// Input variable decorated `BuiltIn GlobalInvocationId`.
    pub(crate) var_gid: u32,
    /// Imported `OpenCL.std` extended instruction set.
    pub(crate) opencl_ext: u32,
}
354
355pub(crate) fn emit_preamble(m: &mut SpvModule) -> BaseIds {
361 let ty_void = m.alloc_id();
362 let ty_bool = m.alloc_id();
363 let ty_uint = m.alloc_id();
364 let ty_float = m.alloc_id();
365 let ty_v3uint = m.alloc_id();
366 let ty_fn_void = m.alloc_id();
367 let ty_ptr_input_v3uint = m.alloc_id();
368 let ty_ptr_cross_float = m.alloc_id();
369 let ty_ptr_func_float = m.alloc_id();
370 let ty_ptr_func_uint = m.alloc_id();
371 let c_uint_0 = m.alloc_id();
372 let c_uint_1 = m.alloc_id();
373 let c_float_0 = m.alloc_id();
374 let c_float_1 = m.alloc_id();
375 let var_gid = m.alloc_id();
376 let opencl_ext = m.alloc_id();
377
378 m.emit_capability(CAPABILITY_KERNEL);
380 m.emit_capability(CAPABILITY_ADDRESSES);
381
382 m.emit_ext_inst_import(opencl_ext, "OpenCL.std");
384
385 m.emit_memory_model(ADDRESSING_MODEL_PHYSICAL64, MEMORY_MODEL_OPENCL);
387
388 m.emit_decorate(var_gid, DECORATION_BUILTIN, &[BUILTIN_GLOBAL_INVOCATION_ID]);
393
394 m.emit_type_void(ty_void);
396 m.emit_type_bool(ty_bool);
397 m.emit_type_int(ty_uint, 32, 0);
398 m.emit_type_float(ty_float, 32);
399 m.emit_type_vector(ty_v3uint, ty_uint, 3);
400 m.emit_type_function(ty_fn_void, ty_void, &[]);
401 m.emit_type_pointer(ty_ptr_input_v3uint, STORAGE_CLASS_INPUT, ty_v3uint);
402 m.emit_type_pointer(ty_ptr_cross_float, STORAGE_CLASS_CROSS_WORKGROUP, ty_float);
403 m.emit_type_pointer(ty_ptr_func_float, STORAGE_CLASS_FUNCTION, ty_float);
404 m.emit_type_pointer(ty_ptr_func_uint, STORAGE_CLASS_FUNCTION, ty_uint);
405
406 m.emit_constant_u32(ty_uint, c_uint_0, 0);
408 m.emit_constant_u32(ty_uint, c_uint_1, 1);
409 m.emit_constant_f32(ty_float, c_float_0, 0.0);
410 m.emit_constant_f32(ty_float, c_float_1, 1.0);
411
412 m.emit_variable(ty_ptr_input_v3uint, var_gid, STORAGE_CLASS_INPUT);
414
415 BaseIds {
416 ty_void,
417 ty_bool,
418 ty_uint,
419 ty_float,
420 ty_v3uint,
421 ty_fn_void,
422 ty_ptr_input_v3uint,
423 ty_ptr_cross_float,
424 ty_ptr_func_float,
425 ty_ptr_func_uint,
426 c_uint_0,
427 c_uint_1,
428 c_float_0,
429 c_float_1,
430 var_gid,
431 opencl_ext,
432 }
433}
434
435pub(crate) fn load_gid_x(m: &mut SpvModule, b: &BaseIds) -> u32 {
437 let gid_val = m.alloc_id();
438 m.emit_load(b.ty_v3uint, gid_val, b.var_gid);
439 let gid_x = m.alloc_id();
440 m.emit(OP_COMPOSITE_EXTRACT, &[b.ty_uint, gid_x, gid_val, 0]);
441 gid_x
442}
443
444pub fn unary_compute_shader(op: UnaryOp) -> Vec<u32> {
450 let mut m = SpvModule::new();
451 let b = emit_preamble(&mut m);
452
453 let main_fn = m.alloc_id();
454 let fn_ty = m.alloc_id();
455 let p_input = m.alloc_id();
456 let p_output = m.alloc_id();
457 let p_count = m.alloc_id();
458
459 m.emit_type_function(
461 fn_ty,
462 b.ty_void,
463 &[b.ty_ptr_cross_float, b.ty_ptr_cross_float, b.ty_uint],
464 );
465
466 m.emit_entry_point(EXECUTION_MODEL_KERNEL, main_fn, "main", &[b.var_gid]);
468 m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);
469
470 let label_entry = m.alloc_id();
472 let label_body = m.alloc_id();
473 let label_merge = m.alloc_id();
474
475 m.emit_function(b.ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
477 m.emit_function_parameter(b.ty_ptr_cross_float, p_input);
478 m.emit_function_parameter(b.ty_ptr_cross_float, p_output);
479 m.emit_function_parameter(b.ty_uint, p_count);
480 m.emit_label(label_entry);
481
482 let gid = load_gid_x(&mut m, &b);
483
484 let cond = m.alloc_id();
486 m.emit(OP_U_LESS_THAN, &[b.ty_bool, cond, gid, p_count]);
487 m.emit_selection_merge(label_merge);
488 m.emit_branch_conditional(cond, label_body, label_merge);
489
490 m.emit_label(label_body);
491
492 let inp_ptr = m.alloc_id();
494 m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, inp_ptr, p_input, gid);
495 let inp_val = m.alloc_id();
496 m.emit_load(b.ty_float, inp_val, inp_ptr);
497
498 let result = emit_unary_op(&mut m, &b, op, inp_val);
499
500 let out_ptr = m.alloc_id();
502 m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, out_ptr, p_output, gid);
503 m.emit_store(out_ptr, result);
504
505 m.emit_branch(label_merge);
506
507 m.emit_label(label_merge);
508 m.emit_return();
509 m.emit_function_end();
510
511 m.finalize()
512}
513
514fn emit_unary_op(m: &mut SpvModule, b: &BaseIds, op: UnaryOp, x: u32) -> u32 {
516 let result = m.alloc_id();
517 match op {
518 UnaryOp::Relu => {
519 m.emit_opencl_ext(
520 b.opencl_ext,
521 b.ty_float,
522 result,
523 OPENCL_FMAX,
524 &[b.c_float_0, x],
525 );
526 }
527 UnaryOp::Sigmoid => {
528 let neg_x = m.alloc_id();
529 m.emit(OP_F_NEGATE, &[b.ty_float, neg_x, x]);
530 let exp_neg_x = m.alloc_id();
531 m.emit_opencl_ext(b.opencl_ext, b.ty_float, exp_neg_x, OPENCL_EXP, &[neg_x]);
532 let one_plus = m.alloc_id();
533 m.emit(OP_F_ADD, &[b.ty_float, one_plus, b.c_float_1, exp_neg_x]);
534 m.emit(OP_F_DIV, &[b.ty_float, result, b.c_float_1, one_plus]);
535 }
536 UnaryOp::Tanh => {
537 m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_TANH, &[x]);
538 }
539 UnaryOp::Exp => {
540 m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_EXP, &[x]);
541 }
542 UnaryOp::Log => {
543 m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_LOG, &[x]);
544 }
545 UnaryOp::Sqrt => {
546 m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_SQRT, &[x]);
547 }
548 UnaryOp::Abs => {
549 m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_FABS, &[x]);
550 }
551 UnaryOp::Neg => {
552 m.emit(OP_F_NEGATE, &[b.ty_float, result, x]);
553 }
554 }
555 result
556}
557
558pub fn binary_compute_shader(op: BinaryOp) -> Vec<u32> {
565 let mut m = SpvModule::new();
566 let b = emit_preamble(&mut m);
567
568 let main_fn = m.alloc_id();
569 let fn_ty = m.alloc_id();
570 let p_a = m.alloc_id();
571 let p_b = m.alloc_id();
572 let p_out = m.alloc_id();
573 let p_count = m.alloc_id();
574
575 m.emit_type_function(
577 fn_ty,
578 b.ty_void,
579 &[
580 b.ty_ptr_cross_float,
581 b.ty_ptr_cross_float,
582 b.ty_ptr_cross_float,
583 b.ty_uint,
584 ],
585 );
586
587 m.emit_entry_point(EXECUTION_MODEL_KERNEL, main_fn, "main", &[b.var_gid]);
588 m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);
589
590 let label_entry = m.alloc_id();
591 let label_body = m.alloc_id();
592 let label_merge = m.alloc_id();
593
594 m.emit_function(b.ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
595 m.emit_function_parameter(b.ty_ptr_cross_float, p_a);
596 m.emit_function_parameter(b.ty_ptr_cross_float, p_b);
597 m.emit_function_parameter(b.ty_ptr_cross_float, p_out);
598 m.emit_function_parameter(b.ty_uint, p_count);
599 m.emit_label(label_entry);
600
601 let gid = load_gid_x(&mut m, &b);
602
603 let cond = m.alloc_id();
604 m.emit(OP_U_LESS_THAN, &[b.ty_bool, cond, gid, p_count]);
605 m.emit_selection_merge(label_merge);
606 m.emit_branch_conditional(cond, label_body, label_merge);
607
608 m.emit_label(label_body);
609
610 let a_ptr = m.alloc_id();
611 m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, a_ptr, p_a, gid);
612 let a_val = m.alloc_id();
613 m.emit_load(b.ty_float, a_val, a_ptr);
614
615 let b_ptr = m.alloc_id();
616 m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, b_ptr, p_b, gid);
617 let b_val = m.alloc_id();
618 m.emit_load(b.ty_float, b_val, b_ptr);
619
620 let result = emit_binary_op(&mut m, &b, op, a_val, b_val);
621
622 let out_ptr = m.alloc_id();
623 m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, out_ptr, p_out, gid);
624 m.emit_store(out_ptr, result);
625
626 m.emit_branch(label_merge);
627
628 m.emit_label(label_merge);
629 m.emit_return();
630 m.emit_function_end();
631
632 m.finalize()
633}
634
635fn emit_binary_op(m: &mut SpvModule, b: &BaseIds, op: BinaryOp, lhs: u32, rhs: u32) -> u32 {
636 let result = m.alloc_id();
637 match op {
638 BinaryOp::Add => m.emit(OP_F_ADD, &[b.ty_float, result, lhs, rhs]),
639 BinaryOp::Sub => m.emit(OP_F_SUB, &[b.ty_float, result, lhs, rhs]),
640 BinaryOp::Mul => m.emit(OP_F_MUL, &[b.ty_float, result, lhs, rhs]),
641 BinaryOp::Div => m.emit(OP_F_DIV, &[b.ty_float, result, lhs, rhs]),
642 BinaryOp::Max => {
643 m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_FMAX, &[lhs, rhs]);
644 }
645 BinaryOp::Min => {
646 m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_FMIN, &[lhs, rhs]);
647 }
648 }
649 result
650}
651
652pub fn reduce_compute_shader(op: ReduceOp) -> Vec<u32> {
661 let mut m = SpvModule::new();
662 let b = emit_preamble(&mut m);
663
664 let main_fn = m.alloc_id();
665 let fn_ty = m.alloc_id();
666 let p_input = m.alloc_id();
667 let p_output = m.alloc_id();
668 let p_outer = m.alloc_id();
669 let p_reduce = m.alloc_id();
670 let p_inner = m.alloc_id();
671
672 m.emit_type_function(
674 fn_ty,
675 b.ty_void,
676 &[
677 b.ty_ptr_cross_float,
678 b.ty_ptr_cross_float,
679 b.ty_uint,
680 b.ty_uint,
681 b.ty_uint,
682 ],
683 );
684
685 m.emit_entry_point(EXECUTION_MODEL_KERNEL, main_fn, "main", &[b.var_gid]);
686 m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);
687
688 let label_entry = m.alloc_id();
689 let label_bounds_body = m.alloc_id();
690 let label_bounds_merge = m.alloc_id();
691 let label_loop_header = m.alloc_id();
692 let label_loop_body = m.alloc_id();
693 let label_loop_continue = m.alloc_id();
694 let label_loop_merge = m.alloc_id();
695
696 m.emit_function(b.ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
697 m.emit_function_parameter(b.ty_ptr_cross_float, p_input);
698 m.emit_function_parameter(b.ty_ptr_cross_float, p_output);
699 m.emit_function_parameter(b.ty_uint, p_outer);
700 m.emit_function_parameter(b.ty_uint, p_reduce);
701 m.emit_function_parameter(b.ty_uint, p_inner);
702 m.emit_label(label_entry);
703
704 let gid = load_gid_x(&mut m, &b);
705
706 let total_output = m.alloc_id();
708 m.emit(OP_I_MUL, &[b.ty_uint, total_output, p_outer, p_inner]);
709
710 let cond_bounds = m.alloc_id();
712 m.emit(OP_U_LESS_THAN, &[b.ty_bool, cond_bounds, gid, total_output]);
713 m.emit_selection_merge(label_bounds_merge);
714 m.emit_branch_conditional(cond_bounds, label_bounds_body, label_bounds_merge);
715
716 m.emit_label(label_bounds_body);
717
718 let outer_idx = m.alloc_id();
720 m.emit(OP_U_DIV, &[b.ty_uint, outer_idx, gid, p_inner]);
721 let inner_idx = m.alloc_id();
722 m.emit(OP_U_MOD, &[b.ty_uint, inner_idx, gid, p_inner]);
723
724 let t1 = m.alloc_id();
726 m.emit(OP_I_MUL, &[b.ty_uint, t1, outer_idx, p_reduce]);
727 let t2 = m.alloc_id();
728 m.emit(OP_I_MUL, &[b.ty_uint, t2, t1, p_inner]);
729 let base_idx = m.alloc_id();
730 m.emit(OP_I_ADD, &[b.ty_uint, base_idx, t2, inner_idx]);
731
732 let var_i = m.alloc_id();
734 m.emit_variable(b.ty_ptr_func_uint, var_i, STORAGE_CLASS_FUNCTION);
735 m.emit_store(var_i, b.c_uint_0);
736
737 let var_acc = m.alloc_id();
739 m.emit_variable(b.ty_ptr_func_float, var_acc, STORAGE_CLASS_FUNCTION);
740 let init_val = match op {
741 ReduceOp::Sum | ReduceOp::Mean => b.c_float_0,
742 ReduceOp::Max => {
743 let neg_inf = m.alloc_id();
744 m.emit_constant_f32(b.ty_float, neg_inf, f32::NEG_INFINITY);
745 neg_inf
746 }
747 ReduceOp::Min => {
748 let pos_inf = m.alloc_id();
749 m.emit_constant_f32(b.ty_float, pos_inf, f32::INFINITY);
750 pos_inf
751 }
752 };
753 m.emit_store(var_acc, init_val);
754
755 m.emit_branch(label_loop_header);
756
757 m.emit_label(label_loop_header);
759 let i_val = m.alloc_id();
760 m.emit_load(b.ty_uint, i_val, var_i);
761 let loop_cond = m.alloc_id();
762 m.emit(OP_U_LESS_THAN, &[b.ty_bool, loop_cond, i_val, p_reduce]);
763 m.emit_loop_merge(label_loop_merge, label_loop_continue);
764 m.emit_branch_conditional(loop_cond, label_loop_body, label_loop_merge);
765
766 m.emit_label(label_loop_body);
768
769 let i_times_inner = m.alloc_id();
771 m.emit(OP_I_MUL, &[b.ty_uint, i_times_inner, i_val, p_inner]);
772 let input_idx = m.alloc_id();
773 m.emit(OP_I_ADD, &[b.ty_uint, input_idx, base_idx, i_times_inner]);
774
775 let inp_ptr = m.alloc_id();
776 m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, inp_ptr, p_input, input_idx);
777 let inp_val = m.alloc_id();
778 m.emit_load(b.ty_float, inp_val, inp_ptr);
779
780 let acc_val = m.alloc_id();
781 m.emit_load(b.ty_float, acc_val, var_acc);
782
783 let new_acc = m.alloc_id();
784 match op {
785 ReduceOp::Sum | ReduceOp::Mean => {
786 m.emit(OP_F_ADD, &[b.ty_float, new_acc, acc_val, inp_val]);
787 }
788 ReduceOp::Max => {
789 m.emit_opencl_ext(
790 b.opencl_ext,
791 b.ty_float,
792 new_acc,
793 OPENCL_FMAX,
794 &[acc_val, inp_val],
795 );
796 }
797 ReduceOp::Min => {
798 m.emit_opencl_ext(
799 b.opencl_ext,
800 b.ty_float,
801 new_acc,
802 OPENCL_FMIN,
803 &[acc_val, inp_val],
804 );
805 }
806 }
807 m.emit_store(var_acc, new_acc);
808
809 m.emit_branch(label_loop_continue);
810
811 m.emit_label(label_loop_continue);
813 let i_inc = m.alloc_id();
814 m.emit(OP_I_ADD, &[b.ty_uint, i_inc, i_val, b.c_uint_1]);
815 m.emit_store(var_i, i_inc);
816 m.emit_branch(label_loop_header);
817
818 m.emit_label(label_loop_merge);
820
821 let final_acc = m.alloc_id();
822 m.emit_load(b.ty_float, final_acc, var_acc);
823
824 let store_val = if op == ReduceOp::Mean {
825 let reduce_f = m.alloc_id();
826 m.emit(OP_CONVERT_U_TO_F, &[b.ty_float, reduce_f, p_reduce]);
827 let mean_val = m.alloc_id();
828 m.emit(OP_F_DIV, &[b.ty_float, mean_val, final_acc, reduce_f]);
829 mean_val
830 } else {
831 final_acc
832 };
833
834 let out_ptr = m.alloc_id();
835 m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, out_ptr, p_output, gid);
836 m.emit_store(out_ptr, store_val);
837
838 m.emit_branch(label_bounds_merge);
839
840 m.emit_label(label_bounds_merge);
841 m.emit_return();
842 m.emit_function_end();
843
844 m.finalize()
845}
846
847pub fn gemm_compute_shader() -> Vec<u32> {
857 let mut m = SpvModule::new();
858 let b = emit_preamble(&mut m);
859
860 let main_fn = m.alloc_id();
861 let fn_ty = m.alloc_id();
862 let p_a = m.alloc_id();
863 let p_b = m.alloc_id();
864 let p_c = m.alloc_id();
865 let p_m = m.alloc_id();
866 let p_n = m.alloc_id();
867 let p_k = m.alloc_id();
868 let p_alpha = m.alloc_id();
869 let p_beta = m.alloc_id();
870
871 m.emit_type_function(
873 fn_ty,
874 b.ty_void,
875 &[
876 b.ty_ptr_cross_float,
877 b.ty_ptr_cross_float,
878 b.ty_ptr_cross_float,
879 b.ty_uint,
880 b.ty_uint,
881 b.ty_uint,
882 b.ty_float,
883 b.ty_float,
884 ],
885 );
886
887 m.emit_entry_point(EXECUTION_MODEL_KERNEL, main_fn, "main", &[b.var_gid]);
888 m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);
889
890 let label_entry = m.alloc_id();
891 let label_bounds_body = m.alloc_id();
892 let label_bounds_merge = m.alloc_id();
893 let label_loop_header = m.alloc_id();
894 let label_loop_body = m.alloc_id();
895 let label_loop_continue = m.alloc_id();
896 let label_loop_merge = m.alloc_id();
897
898 m.emit_function(b.ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
899 m.emit_function_parameter(b.ty_ptr_cross_float, p_a);
900 m.emit_function_parameter(b.ty_ptr_cross_float, p_b);
901 m.emit_function_parameter(b.ty_ptr_cross_float, p_c);
902 m.emit_function_parameter(b.ty_uint, p_m);
903 m.emit_function_parameter(b.ty_uint, p_n);
904 m.emit_function_parameter(b.ty_uint, p_k);
905 m.emit_function_parameter(b.ty_float, p_alpha);
906 m.emit_function_parameter(b.ty_float, p_beta);
907 m.emit_label(label_entry);
908
909 let gid = load_gid_x(&mut m, &b);
910
911 let total = m.alloc_id();
913 m.emit(OP_I_MUL, &[b.ty_uint, total, p_m, p_n]);
914
915 let cond = m.alloc_id();
917 m.emit(OP_U_LESS_THAN, &[b.ty_bool, cond, gid, total]);
918 m.emit_selection_merge(label_bounds_merge);
919 m.emit_branch_conditional(cond, label_bounds_body, label_bounds_merge);
920
921 m.emit_label(label_bounds_body);
922
923 let row = m.alloc_id();
925 m.emit(OP_U_DIV, &[b.ty_uint, row, gid, p_n]);
926 let col = m.alloc_id();
927 m.emit(OP_U_MOD, &[b.ty_uint, col, gid, p_n]);
928
929 let var_i = m.alloc_id();
931 m.emit_variable(b.ty_ptr_func_uint, var_i, STORAGE_CLASS_FUNCTION);
932 m.emit_store(var_i, b.c_uint_0);
933 let var_acc = m.alloc_id();
934 m.emit_variable(b.ty_ptr_func_float, var_acc, STORAGE_CLASS_FUNCTION);
935 m.emit_store(var_acc, b.c_float_0);
936
937 m.emit_branch(label_loop_header);
938
939 m.emit_label(label_loop_header);
941 let i_val = m.alloc_id();
942 m.emit_load(b.ty_uint, i_val, var_i);
943 let loop_cond = m.alloc_id();
944 m.emit(OP_U_LESS_THAN, &[b.ty_bool, loop_cond, i_val, p_k]);
945 m.emit_loop_merge(label_loop_merge, label_loop_continue);
946 m.emit_branch_conditional(loop_cond, label_loop_body, label_loop_merge);
947
948 m.emit_label(label_loop_body);
950
951 let row_k = m.alloc_id();
953 m.emit(OP_I_MUL, &[b.ty_uint, row_k, row, p_k]);
954 let a_idx = m.alloc_id();
955 m.emit(OP_I_ADD, &[b.ty_uint, a_idx, row_k, i_val]);
956
957 let i_n = m.alloc_id();
959 m.emit(OP_I_MUL, &[b.ty_uint, i_n, i_val, p_n]);
960 let b_idx = m.alloc_id();
961 m.emit(OP_I_ADD, &[b.ty_uint, b_idx, i_n, col]);
962
963 let a_ptr = m.alloc_id();
964 m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, a_ptr, p_a, a_idx);
965 let a_val = m.alloc_id();
966 m.emit_load(b.ty_float, a_val, a_ptr);
967
968 let b_ptr = m.alloc_id();
969 m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, b_ptr, p_b, b_idx);
970 let b_val = m.alloc_id();
971 m.emit_load(b.ty_float, b_val, b_ptr);
972
973 let prod = m.alloc_id();
974 m.emit(OP_F_MUL, &[b.ty_float, prod, a_val, b_val]);
975 let old_acc = m.alloc_id();
976 m.emit_load(b.ty_float, old_acc, var_acc);
977 let new_acc = m.alloc_id();
978 m.emit(OP_F_ADD, &[b.ty_float, new_acc, old_acc, prod]);
979 m.emit_store(var_acc, new_acc);
980
981 m.emit_branch(label_loop_continue);
982
983 m.emit_label(label_loop_continue);
985 let i_inc = m.alloc_id();
986 m.emit(OP_I_ADD, &[b.ty_uint, i_inc, i_val, b.c_uint_1]);
987 m.emit_store(var_i, i_inc);
988 m.emit_branch(label_loop_header);
989
990 m.emit_label(label_loop_merge);
992
993 let final_acc = m.alloc_id();
995 m.emit_load(b.ty_float, final_acc, var_acc);
996 let alpha_acc = m.alloc_id();
997 m.emit(OP_F_MUL, &[b.ty_float, alpha_acc, p_alpha, final_acc]);
998
999 let c_ptr = m.alloc_id();
1000 m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, c_ptr, p_c, gid);
1001 let c_old = m.alloc_id();
1002 m.emit_load(b.ty_float, c_old, c_ptr);
1003 let beta_c = m.alloc_id();
1004 m.emit(OP_F_MUL, &[b.ty_float, beta_c, p_beta, c_old]);
1005 let c_new = m.alloc_id();
1006 m.emit(OP_F_ADD, &[b.ty_float, c_new, alpha_acc, beta_c]);
1007 m.emit_store(c_ptr, c_new);
1008
1009 m.emit_branch(label_bounds_merge);
1010
1011 m.emit_label(label_bounds_merge);
1012 m.emit_return();
1013 m.emit_function_end();
1014
1015 m.finalize()
1016}
1017
1018pub fn trivial_compute_shader() -> Vec<u32> {
1024 let mut m = SpvModule::new();
1025
1026 let id_main_fn = m.alloc_id();
1027 let id_void = m.alloc_id();
1028 let id_void_fn = m.alloc_id();
1029 let id_label = m.alloc_id();
1030
1031 m.emit_capability(CAPABILITY_SHADER);
1032 m.emit_memory_model(ADDRESSING_MODEL_LOGICAL, MEMORY_MODEL_GLSL450);
1033
1034 let mut entry_words = vec![EXECUTION_MODEL_GLCOMPUTE, id_main_fn];
1035 entry_words.extend(SpvModule::string_words("main"));
1036 m.emit(OP_ENTRY_POINT, &entry_words);
1037
1038 m.emit_execution_mode_local_size(id_main_fn, 1, 1, 1);
1039
1040 m.emit_type_void(id_void);
1041 m.emit_type_function(id_void_fn, id_void, &[]);
1042
1043 m.emit_function(id_void, id_main_fn, FUNCTION_CONTROL_NONE, id_void_fn);
1044 m.emit_label(id_label);
1045 m.emit_return();
1046 m.emit_function_end();
1047
1048 m.finalize()
1049}
1050
1051pub fn trivial_compute_shader_bytes() -> Vec<u8> {
1054 trivial_compute_shader()
1055 .iter()
1056 .flat_map(|w| w.to_ne_bytes())
1057 .collect()
1058}
1059
#[cfg(test)]
mod tests {
    use super::*;

    /// Splits the body of a module into `(opcode, operands)` pairs, asserting
    /// that per-instruction word counts tile the stream exactly.
    fn instructions(words: &[u32]) -> Vec<(u32, &[u32])> {
        let mut out = Vec::new();
        let mut at = 5;
        while at < words.len() {
            let count = (words[at] >> 16) as usize;
            assert!(count >= 1, "instruction at word {} has a zero word count", at);
            assert!(
                at + count <= words.len(),
                "instruction at word {} overruns the module",
                at
            );
            out.push((words[at] & 0xFFFF, &words[at + 1..at + count]));
            at += count;
        }
        out
    }

    /// Header sanity plus a well-formed, non-empty instruction stream.
    fn check_valid_spirv(words: &[u32]) {
        assert!(words.len() >= 5, "too short for SPIR-V header");
        assert_eq!(words[0], SPIRV_MAGIC, "bad magic");
        assert!(words[3] > 0, "ID bound must be > 0");
        assert_eq!(words[4], 0, "schema must be 0");
        assert!(!instructions(words).is_empty(), "module has no instructions");
    }

    /// Operands of the module's `OpMemoryModel` instruction.
    fn memory_model_operands(words: &[u32]) -> Vec<u32> {
        instructions(words)
            .into_iter()
            .find(|&(opcode, _)| opcode == OP_MEMORY_MODEL)
            .map(|(_, operands)| operands.to_vec())
            .expect("module has no OpMemoryModel")
    }

    #[test]
    fn placeholder_spv_valid_magic() {
        let words = trivial_compute_shader();
        check_valid_spirv(&words);
    }

    #[test]
    fn placeholder_spv_word_aligned() {
        let bytes = trivial_compute_shader_bytes();
        assert_eq!(bytes.len() % 4, 0);
    }

    #[test]
    fn placeholder_spv_version_and_schema() {
        let words = trivial_compute_shader();
        assert!(words[1] >= 0x0001_0000);
        assert_eq!(words[4], 0);
    }

    #[test]
    fn placeholder_spv_nonzero_bound() {
        let words = trivial_compute_shader();
        assert!(words[3] > 0);
    }

    #[test]
    fn spv_module_id_allocation_is_monotonic() {
        let mut m = SpvModule::new();
        let id1 = m.alloc_id();
        let id2 = m.alloc_id();
        assert!(id2 > id1);
    }

    #[test]
    fn string_words_null_terminated() {
        let words = SpvModule::string_words("abc");
        assert!(!words.is_empty());
        let bytes: Vec<u8> = words.iter().flat_map(|w| w.to_le_bytes()).collect();
        assert_eq!(bytes[0], b'a');
        assert_eq!(bytes[1], b'b');
        assert_eq!(bytes[2], b'c');
        assert_eq!(bytes[3], 0);
    }

    #[test]
    fn string_words_empty_string() {
        let words = SpvModule::string_words("");
        assert!(!words.is_empty());
        let bytes: Vec<u8> = words.iter().flat_map(|w| w.to_le_bytes()).collect();
        assert_eq!(bytes[0], 0);
    }

    #[test]
    fn string_words_four_chars_gets_separate_terminator_word() {
        let words = SpvModule::string_words("main");
        assert_eq!(words, vec![u32::from_le_bytes(*b"main"), 0]);
    }

    #[test]
    fn generator_magic_is_level_zero() {
        assert_eq!(SPIRV_GENERATOR, 0x000D_0002);
        assert_ne!(SPIRV_GENERATOR, 0x000D_0001);
    }

    #[test]
    fn unary_shader_all_ops() {
        let ops = [
            UnaryOp::Relu,
            UnaryOp::Sigmoid,
            UnaryOp::Tanh,
            UnaryOp::Exp,
            UnaryOp::Log,
            UnaryOp::Sqrt,
            UnaryOp::Abs,
            UnaryOp::Neg,
        ];
        for op in ops {
            let words = unary_compute_shader(op);
            check_valid_spirv(&words);
        }
    }

    #[test]
    fn binary_shader_all_ops() {
        let ops = [
            BinaryOp::Add,
            BinaryOp::Sub,
            BinaryOp::Mul,
            BinaryOp::Div,
            BinaryOp::Max,
            BinaryOp::Min,
        ];
        for op in ops {
            let words = binary_compute_shader(op);
            check_valid_spirv(&words);
        }
    }

    #[test]
    fn reduce_shader_all_ops() {
        let ops = [ReduceOp::Sum, ReduceOp::Max, ReduceOp::Min, ReduceOp::Mean];
        for op in ops {
            let words = reduce_compute_shader(op);
            check_valid_spirv(&words);
        }
    }

    #[test]
    fn gemm_shader_valid() {
        let words = gemm_compute_shader();
        check_valid_spirv(&words);
    }

    #[test]
    fn all_kernel_shaders_word_aligned() {
        fn to_bytes(words: &[u32]) -> Vec<u8> {
            words.iter().flat_map(|w| w.to_ne_bytes()).collect()
        }
        assert_eq!(to_bytes(&unary_compute_shader(UnaryOp::Relu)).len() % 4, 0);
        assert_eq!(to_bytes(&binary_compute_shader(BinaryOp::Add)).len() % 4, 0);
        assert_eq!(to_bytes(&reduce_compute_shader(ReduceOp::Sum)).len() % 4, 0);
        assert_eq!(to_bytes(&gemm_compute_shader()).len() % 4, 0);
    }

    #[test]
    fn every_kernel_declares_exactly_one_entry_point() {
        for words in [
            trivial_compute_shader(),
            unary_compute_shader(UnaryOp::Relu),
            binary_compute_shader(BinaryOp::Add),
            reduce_compute_shader(ReduceOp::Sum),
            gemm_compute_shader(),
        ] {
            let entry_points = instructions(&words)
                .iter()
                .filter(|(opcode, _)| *opcode == OP_ENTRY_POINT)
                .count();
            assert_eq!(entry_points, 1);
        }
    }

    #[test]
    fn kernel_shaders_use_opencl_memory_model() {
        let trivial = trivial_compute_shader();
        let unary = unary_compute_shader(UnaryOp::Relu);

        // The first instruction after the 5-word header is the first OpCapability.
        let cap_header = (2u32 << 16) | OP_CAPABILITY;
        assert_eq!(trivial[5], cap_header);
        assert_eq!(trivial[6], CAPABILITY_SHADER);
        assert_eq!(unary[5], cap_header);
        assert_eq!(unary[6], CAPABILITY_KERNEL);

        assert_eq!(
            memory_model_operands(&trivial),
            [ADDRESSING_MODEL_LOGICAL, MEMORY_MODEL_GLSL450]
        );
        for kernel in [
            unary,
            binary_compute_shader(BinaryOp::Add),
            reduce_compute_shader(ReduceOp::Sum),
            gemm_compute_shader(),
        ] {
            assert_eq!(
                memory_model_operands(&kernel),
                [ADDRESSING_MODEL_PHYSICAL64, MEMORY_MODEL_OPENCL]
            );
        }
    }
}