1use dynamic::Type;
2use rspirv::{binary::Disassemble, dr::Module};
3use smol_str::SmolStr;
4use std::rc::Rc;
5
/// A SPIR-V module held in two forms: its binary encoding (`words`) and an
/// rspirv data-representation `Module` (`module`).
#[derive(Debug, Clone)]
pub struct SpirvModule {
    /// Binary encoding of the module as 32-bit words; exposed via `words()` / `into_words()`.
    pub(crate) words: Vec<u32>,
    /// rspirv `dr::Module` form, read by `disassemble()`.
    /// NOTE(review): presumably describes the same module as `words` — confirm at the construction site.
    pub(crate) module: Module,
}
11
12impl SpirvModule {
13 pub fn words(&self) -> &[u32] {
14 &self.words
15 }
16
17 pub fn into_words(self) -> Vec<u32> {
18 self.words
19 }
20
21 pub fn disassemble(&self) -> String {
22 self.module.disassemble()
23 }
24}
25
/// A kernel compiled to SPIR-V, together with its entry-point name and the
/// argument/return types recorded for it.
#[derive(Debug, Clone)]
pub struct Kernel {
    /// The SPIR-V module holding the kernel's code.
    pub spirv: SpirvModule,
    /// Entry-point name.
    /// NOTE(review): presumably matches an `OpEntryPoint` name inside `spirv` — confirm.
    pub entry: SmolStr,
    /// Argument types, in declaration order.
    pub arg_tys: Vec<Type>,
    /// Return type.
    pub ret_ty: Type,
}
33
/// Declaration of an externally provided function (see `spirv_builtins` for the
/// registered set): its name, signature, and how calls to it are implemented.
#[derive(Debug, Clone)]
pub struct ExternalFn {
    /// Fully qualified name, e.g. `spirv::group_id` or `sqrt`.
    pub full_name: SmolStr,
    /// Parameter types, in order.
    pub arg_tys: Vec<Type>,
    /// Return type.
    pub ret_ty: Type,
    /// Which GLSL.std.450 instruction(s) or built-in this function maps to.
    pub kind: ExternalFnKind,
}
41
/// How an `ExternalFn` is implemented.
#[derive(Debug, Clone)]
pub enum ExternalFnKind {
    /// One-operand GLSL.std.450 instruction: `float_op` for floating-point operands and,
    /// when present, `signed_int_op` for signed integers.
    /// NOTE(review): the op is presumably chosen by operand type during lowering — confirm.
    GlslUnary {
        float_op: spirv::GlslStd450Op,
        signed_int_op: Option<spirv::GlslStd450Op>,
    },
    /// Two-operand GLSL.std.450 instruction with separate float, signed-int and
    /// unsigned-int variants (e.g. `FMin` / `SMin` / `UMin`).
    GlslBinary {
        float_op: spirv::GlslStd450Op,
        signed_int_op: spirv::GlslStd450Op,
        unsigned_int_op: spirv::GlslStd450Op,
    },
    /// Two-operand GLSL.std.450 instruction with only a floating-point form.
    GlslFloatBinary { op: spirv::GlslStd450Op },
    /// Three-operand GLSL.std.450 instruction with only a floating-point form.
    GlslFloatTernary { op: spirv::GlslStd450Op },
    /// A built-in that is not a GLSL.std.450 instruction.
    Builtin(BuiltinFn),
}
50
/// Built-in operations that are not GLSL.std.450 instructions.
///
/// The enum is fieldless and `Copy`. It derives `Hash` alongside `Eq`/`Ord` so it can
/// key hash-based maps and sets as well as ordered ones. Variant order is significant
/// for the derived `Ord`, so do not reorder.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum BuiltinFn {
    /// Registered as `spirv::group_id`: no arguments, returns a 3-component `u32` vector.
    GroupId,
    /// Registered as `spirv::local_id`: no arguments, returns a 3-component `u32` vector.
    LocalId,
    /// Registered as `spirv::barrier`: no arguments, returns `Void`.
    Barrier,
    /// Registered as `spirv::atomic_add`: two `u32` arguments, returns `u32`.
    /// NOTE(review): which operand is the target vs. the addend is decided by the lowering — confirm.
    AtomicAdd,
}
58
59impl ExternalFn {
60 pub fn glsl_unary(full_name: impl Into<SmolStr>, arg_ty: Type, ret_ty: Type, float_op: spirv::GlslStd450Op, signed_int_op: Option<spirv::GlslStd450Op>) -> Self {
61 Self { full_name: full_name.into(), arg_tys: vec![arg_ty], ret_ty, kind: ExternalFnKind::GlslUnary { float_op, signed_int_op } }
62 }
63
64 pub fn glsl_binary(full_name: impl Into<SmolStr>, arg_ty: Type, ret_ty: Type, float_op: spirv::GlslStd450Op, signed_int_op: spirv::GlslStd450Op, unsigned_int_op: spirv::GlslStd450Op) -> Self {
65 Self { full_name: full_name.into(), arg_tys: vec![arg_ty.clone(), arg_ty], ret_ty, kind: ExternalFnKind::GlslBinary { float_op, signed_int_op, unsigned_int_op } }
66 }
67
68 pub fn glsl_float_binary(full_name: impl Into<SmolStr>, arg_ty: Type, ret_ty: Type, op: spirv::GlslStd450Op) -> Self {
69 Self { full_name: full_name.into(), arg_tys: vec![arg_ty.clone(), arg_ty], ret_ty, kind: ExternalFnKind::GlslFloatBinary { op } }
70 }
71
72 pub fn glsl_float_ternary(full_name: impl Into<SmolStr>, arg_ty: Type, ret_ty: Type, op: spirv::GlslStd450Op) -> Self {
73 Self { full_name: full_name.into(), arg_tys: vec![arg_ty.clone(), arg_ty.clone(), arg_ty], ret_ty, kind: ExternalFnKind::GlslFloatTernary { op } }
74 }
75
76 pub fn builtin(full_name: impl Into<SmolStr>, arg_tys: Vec<Type>, ret_ty: Type, builtin: BuiltinFn) -> Self {
77 Self { full_name: full_name.into(), arg_tys, ret_ty, kind: ExternalFnKind::Builtin(builtin) }
78 }
79}
80
81pub fn spirv_builtins() -> Vec<ExternalFn> {
82 vec![
83 ExternalFn::builtin("spirv::group_id", vec![], Type::Vec(Rc::new(Type::U32), 3), BuiltinFn::GroupId),
84 ExternalFn::builtin("spirv::local_id", vec![], Type::Vec(Rc::new(Type::U32), 3), BuiltinFn::LocalId),
85 ExternalFn::builtin("spirv::barrier", vec![], Type::Void, BuiltinFn::Barrier),
86 ExternalFn::builtin("spirv::atomic_add", vec![Type::U32, Type::U32], Type::U32, BuiltinFn::AtomicAdd),
87 ExternalFn::glsl_unary("abs", Type::F32, Type::F32, spirv::GlslStd450Op::FAbs, Some(spirv::GlslStd450Op::SAbs)),
88 ExternalFn::glsl_unary("sign", Type::F32, Type::F32, spirv::GlslStd450Op::FSign, Some(spirv::GlslStd450Op::SSign)),
89 ExternalFn::glsl_unary("floor", Type::F32, Type::F32, spirv::GlslStd450Op::Floor, None),
90 ExternalFn::glsl_unary("ceil", Type::F32, Type::F32, spirv::GlslStd450Op::Ceil, None),
91 ExternalFn::glsl_unary("round", Type::F32, Type::F32, spirv::GlslStd450Op::Round, None),
92 ExternalFn::glsl_unary("round_even", Type::F32, Type::F32, spirv::GlslStd450Op::RoundEven, None),
93 ExternalFn::glsl_unary("trunc", Type::F32, Type::F32, spirv::GlslStd450Op::Trunc, None),
94 ExternalFn::glsl_unary("fract", Type::F32, Type::F32, spirv::GlslStd450Op::Fract, None),
95 ExternalFn::glsl_unary("radians", Type::F32, Type::F32, spirv::GlslStd450Op::Radians, None),
96 ExternalFn::glsl_unary("degrees", Type::F32, Type::F32, spirv::GlslStd450Op::Degrees, None),
97 ExternalFn::glsl_unary("sin", Type::F32, Type::F32, spirv::GlslStd450Op::Sin, None),
98 ExternalFn::glsl_unary("cos", Type::F32, Type::F32, spirv::GlslStd450Op::Cos, None),
99 ExternalFn::glsl_unary("tan", Type::F32, Type::F32, spirv::GlslStd450Op::Tan, None),
100 ExternalFn::glsl_unary("asin", Type::F32, Type::F32, spirv::GlslStd450Op::Asin, None),
101 ExternalFn::glsl_unary("acos", Type::F32, Type::F32, spirv::GlslStd450Op::Acos, None),
102 ExternalFn::glsl_unary("atan", Type::F32, Type::F32, spirv::GlslStd450Op::Atan, None),
103 ExternalFn::glsl_unary("sinh", Type::F32, Type::F32, spirv::GlslStd450Op::Sinh, None),
104 ExternalFn::glsl_unary("cosh", Type::F32, Type::F32, spirv::GlslStd450Op::Cosh, None),
105 ExternalFn::glsl_unary("tanh", Type::F32, Type::F32, spirv::GlslStd450Op::Tanh, None),
106 ExternalFn::glsl_unary("asinh", Type::F32, Type::F32, spirv::GlslStd450Op::Asinh, None),
107 ExternalFn::glsl_unary("acosh", Type::F32, Type::F32, spirv::GlslStd450Op::Acosh, None),
108 ExternalFn::glsl_unary("atanh", Type::F32, Type::F32, spirv::GlslStd450Op::Atanh, None),
109 ExternalFn::glsl_unary("exp", Type::F32, Type::F32, spirv::GlslStd450Op::Exp, None),
110 ExternalFn::glsl_unary("log", Type::F32, Type::F32, spirv::GlslStd450Op::Log, None),
111 ExternalFn::glsl_unary("exp2", Type::F32, Type::F32, spirv::GlslStd450Op::Exp2, None),
112 ExternalFn::glsl_unary("log2", Type::F32, Type::F32, spirv::GlslStd450Op::Log2, None),
113 ExternalFn::glsl_unary("sqrt", Type::F32, Type::F32, spirv::GlslStd450Op::Sqrt, None),
114 ExternalFn::glsl_unary("inverse_sqrt", Type::F32, Type::F32, spirv::GlslStd450Op::InverseSqrt, None),
115 ExternalFn::glsl_float_binary("atan2", Type::F32, Type::F32, spirv::GlslStd450Op::Atan2),
116 ExternalFn::glsl_float_binary("pow", Type::F32, Type::F32, spirv::GlslStd450Op::Pow),
117 ExternalFn::glsl_float_binary("step", Type::F32, Type::F32, spirv::GlslStd450Op::Step),
118 ExternalFn::glsl_binary("min", Type::F32, Type::F32, spirv::GlslStd450Op::FMin, spirv::GlslStd450Op::SMin, spirv::GlslStd450Op::UMin),
119 ExternalFn::glsl_binary("max", Type::F32, Type::F32, spirv::GlslStd450Op::FMax, spirv::GlslStd450Op::SMax, spirv::GlslStd450Op::UMax),
120 ExternalFn::glsl_float_ternary("clamp", Type::F32, Type::F32, spirv::GlslStd450Op::FClamp),
121 ExternalFn::glsl_float_ternary("mix", Type::F32, Type::F32, spirv::GlslStd450Op::FMix),
122 ExternalFn::glsl_float_ternary("smoothstep", Type::F32, Type::F32, spirv::GlslStd450Op::SmoothStep),
123 ExternalFn::glsl_float_ternary("fma", Type::F32, Type::F32, spirv::GlslStd450Op::Fma),
124 ]
125}