jit_spirv_impl/
lib.rs

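//! # jit-spirv
//!
//! `jit_spirv!` generates code that translates GLSL/HLSL/WGSL shader source into
//! a SPIR-V binary word sequence (`Vec<u32>`).
//!
//! The first parameter is an expression giving the textual shader source (a
//! string slice). The following parameters give you finer control over the
//! generated code that compiles your shaders.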
6//!
7//! ## Source Language
8//!
//! `jit-spirv` currently supports three source languages:
10//!
11//! - `glsl`: The shader source is in GLSL (enabled by default);
12//! - `hlsl`: The shader source is in HLSL (enabled by default);
13//! - `wgsl`: The shader source is in WGSL.
14//!
//! Experimental WGSL support for WebGPU is available when the `wgsl` feature is
//! enabled, but currently you have to compile with a nightly toolchain. Due to
//! limitations of the `naga` backend, most of the extra parameters have no
//! effect and only the first entry point is emitted in the SPIR-V.
19//!
20//! ## Shader Stages
21//!
22//! The following shader stages are supported:
23//!
24//! - `vert`: Vertex shader;
25//! - `tesc`: Tessellation control shader (Hull shader);
26//! - `tese`: Tessellation evaluation shader (Domain shader);
27//! - `geom`: Geometry shader;
28//! - `frag`: Fragment shader (Pixel shader);
29//! - `comp`: Compute shader;
30//! - `mesh`: (Mesh shading) Mesh shader;
31//! - `task`: (Mesh shading) Task shader;
32//! - `rgen`: (Raytracing) ray-generation shader;
33//! - `rint`: (Raytracing) intersection shader;
34//! - `rahit`: (Raytracing) any-hit shader;
35//! - `rchit`: (Raytracing) closest-hit shader;
36//! - `rmiss`: (Raytracing) miss shader;
37//! - `rcall`: (Raytracing) callable shader;
38//!
39//! ## Specify Entry Function
40//!
//! By default the compiler looks for an entry point function named `main`. You
//! can also specify the entry function name explicitly:
43//!
44//! ```ignore
45//! jit_spirv!(hlsl_source, hlsl, vert, entry="very_main");
46//! ```
47//!
48//! ## Optimization Preference
49//!
//! To control how the SPIR-V is optimized:
51//!
52//! - `min_size`: Optimize for the minimal output size;
53//! - `max_perf`: Optimize for the best performance;
//! - `no_debug`: Strip all debug information (avoid this if you want to reflect
//!   on the SPIR-V and recover variable names).
56//!
//! ## Macro Definitions
58//!
59//! You can also define macro substitutions:
60//!
61//! ```ignore
62//! jit_spirv!(glsl_source, vert,
63//!     D USE_LIGHTMAP,
64//!     D LIGHTMAP_COUNT="2");
65//! ```
66//!
//! You can also request a specific target environment version:
//!
//! - `vulkan1_0` (or `vulkan`) for Vulkan 1.0 (default, supports SPIR-V 1.0);
//! - `vulkan1_1` for Vulkan 1.1 (supports SPIR-V 1.3);
//! - `vulkan1_2` for Vulkan 1.2 (supports SPIR-V 1.5);
//! - `opengl4_5` (or `opengl`) for OpenGL 4.5 core profile;
//! - `webgpu` for WebGPU.
73//!
//! Of course, once you start relying on macros, your shaders are probably getting
//! so dynamic that this little crate might not be enough. That might be a good
//! time to build your own shader compilation pipeline!
77//!
78//! ## Descriptor Auto-binding
79//!
//! If you are too tired to specify the descriptor binding points yourself, you
//! can switch on `auto_bind`:
82//!
83//! ```ignore
84//! jit_spirv!(r#"
85//!     #version 450 core
86//!     uniform sampler2D limap;
87//!     uniform sampler2D emit_map;
88//!     void main() {}
89//! "#, glsl, frag, auto_bind);
90//! ```
91//!
//! However, this is not recommended unless you have an automated reflection tool
//! to retrieve the actual binding points.
94//!
95//! ## Flip-Y for WebGPU
96//!
//! If you intend to compile WGSL for a WebGPU backend, note that `naga` by
//! default inverts the Y-axis due to the discrepancy in NDC (Normalized Device
//! Coordinates) between WebGPU and Vulkan. If such correction is undesired, you
//! can opt out with `no_y_flip` (only accepted when the `naga` feature is
//! enabled).
101extern crate proc_macro;
102
103mod backends;
104
105#[cfg(not(any(feature = "shaderc", feature = "naga")))]
106compile_error!("no compiler backend enabled; please specify at least one of \
107    the following input source features: `glsl`, `hlsl`, `wgsl`");
108
109use proc_macro::TokenStream;
110use syn::parse::{Parse, ParseStream, Result as ParseResult, Error as ParseError};
111use syn::{parse_macro_input, Ident, LitStr, Token, Expr};
112
/// Language of the shader source text, selected by the `glsl`, `hlsl` or
/// `wgsl` macro parameter.
#[derive(Copy, Clone)]
enum InputSourceLanguage {
    /// No language parameter was given.
    Unknown,
    /// Selected by `glsl`.
    Glsl,
    /// Selected by `hlsl`.
    Hlsl,
    /// Selected by `wgsl`.
    Wgsl,
}
/// SPIR-V version to emit.
///
/// The parameter parser in this file only ever selects 1.0, 1.3 or 1.5 (via
/// the target-environment keywords). The other versions are kept for
/// completeness, which is why they carry `dead_code` allowances.
#[derive(Copy, Clone)]
enum TargetSpirvVersion {
    Spirv1_0,
    // Not selectable from any macro parameter.
    #[allow(dead_code)]
    Spirv1_1,
    // Not selectable from any macro parameter.
    #[allow(dead_code)]
    Spirv1_2,
    Spirv1_3,
    // Not selectable from any macro parameter.
    #[allow(dead_code)]
    Spirv1_4,
    Spirv1_5,
}
/// Target API environment, chosen by the `vulkan*`, `opengl*` or `webgpu`
/// macro parameters.
#[derive(Copy, Clone)]
enum TargetEnvironmentType {
    /// Vulkan (the default).
    Vulkan,
    /// OpenGL.
    OpenGL,
    /// WebGPU.
    WebGpu,
}
/// Optimization preference for the generated SPIR-V.
#[derive(Copy, Clone)]
enum OptimizationLevel {
    /// Selected by `min_size`: optimize for the smallest output.
    MinSize,
    /// Selected by `max_perf` (`hlsl` also selects it by default).
    MaxPerformance,
    /// No optimization; the default.
    None,
}
/// Shader stage to compile for, selected by one of the stage keywords.
/// Stays `Unknown` when no stage keyword is given.
#[derive(Copy, Clone)]
enum ShaderKind {
    Unknown,

    /// `vert`
    Vertex,
    /// `tesc`
    TesselationControl,
    /// `tese`
    TesselationEvaluation,
    /// `geom`
    Geometry,
    /// `frag`
    Fragment,
    /// `comp`
    Compute,

    // Mesh shading pipeline.
    /// `mesh`
    Mesh,
    /// `task`
    Task,

    // Ray-tracing pipeline.
    /// `rgen`
    RayGeneration,
    /// `rint`
    Intersection,
    /// `rahit`
    AnyHit,
    /// `rchit`
    ClosestHit,
    /// `rmiss`
    Miss,
    /// `rcall`
    Callable,
}
165
166struct ShaderCompilationConfig {
167    path: Option<String>,
168    lang: InputSourceLanguage,
169    incl_dirs: Vec<String>,
170    defs: Vec<(String, Option<String>)>,
171    spv_ver: TargetSpirvVersion,
172    env_ty: TargetEnvironmentType,
173    entry: String,
174    optim_lv: OptimizationLevel,
175    debug: bool,
176    kind: ShaderKind,
177    auto_bind: bool,
178    // Backend specific.
179    #[cfg(feature = "naga")]
180    y_flip: bool,
181}
182impl Default for ShaderCompilationConfig {
183    fn default() -> Self {
184        ShaderCompilationConfig {
185            path: None,
186            lang: InputSourceLanguage::Unknown,
187            incl_dirs: Vec::new(),
188            defs: Vec::new(),
189            spv_ver: TargetSpirvVersion::Spirv1_0,
190            env_ty: TargetEnvironmentType::Vulkan,
191            entry: "main".to_owned(),
192            optim_lv: OptimizationLevel::None,
193            debug: true,
194            kind: ShaderKind::Unknown,
195            auto_bind: false,
196
197            #[cfg(feature = "naga")]
198            y_flip: true,
199        }
200    }
201}
202
203struct JitSpirv(TokenStream);
204
205#[inline]
206fn parse_str(input: &mut ParseStream) -> ParseResult<String> {
207    input.parse::<LitStr>()
208        .map(|x| x.value())
209}
210#[inline]
211fn parse_ident(input: &mut ParseStream) -> ParseResult<String> {
212    input.parse::<Ident>()
213        .map(|x| x.to_string())
214}
215
216fn parse_compile_cfg(
217    input: &mut ParseStream
218) -> ParseResult<ShaderCompilationConfig> {
219    let mut cfg = ShaderCompilationConfig::default();
220    while !input.is_empty() {
221        use syn::Error;
222        // Capture comma and collon; they are for readability.
223        input.parse::<Token![,]>()?;
224        let k = if let Ok(k) = input.parse::<Ident>() { k } else { break };
225        match &k.to_string() as &str {
226            "path" => {
227                input.parse::<Token![,]>()?;
228                cfg.path = Some(parse_str(input)?);
229            },
230
231            "glsl" => cfg.lang = InputSourceLanguage::Glsl,
232            "hlsl" => {
233                cfg.lang = InputSourceLanguage::Hlsl;
234                // HLSL might be illegal if optimization is disabled. Not sure,
235                // `glslangValidator` said this.
236                cfg.optim_lv = OptimizationLevel::MaxPerformance;
237            },
238            "wgsl" => cfg.lang = InputSourceLanguage::Wgsl,
239
240            "vert" => cfg.kind = ShaderKind::Vertex,
241            "tesc" => cfg.kind = ShaderKind::TesselationControl,
242            "tese" => cfg.kind = ShaderKind::TesselationEvaluation,
243            "geom" => cfg.kind = ShaderKind::Geometry,
244            "frag" => cfg.kind = ShaderKind::Fragment,
245            "comp" => cfg.kind = ShaderKind::Compute,
246            "mesh" => cfg.kind = ShaderKind::Mesh,
247            "task" => cfg.kind = ShaderKind::Task,
248            "rgen" => cfg.kind = ShaderKind::RayGeneration,
249            "rint" => cfg.kind = ShaderKind::Intersection,
250            "rahit" => cfg.kind = ShaderKind::AnyHit,
251            "rchit" => cfg.kind = ShaderKind::ClosestHit,
252            "rmiss" => cfg.kind = ShaderKind::Miss,
253            "rcall" => cfg.kind = ShaderKind::Callable,
254
255            "I" => {
256                cfg.incl_dirs.push(parse_str(input)?)
257            },
258            "D" => {
259                let k = parse_ident(input)?;
260                let v = if input.parse::<Token![=]>().is_ok() {
261                    Some(parse_str(input)?)
262                } else { None };
263                cfg.defs.push((k, v));
264            },
265
266            "entry" => {
267                if input.parse::<Token![=]>().is_ok() {
268                    cfg.entry = parse_str(input)?.to_owned();
269                }
270            }
271
272            "min_size" => cfg.optim_lv = OptimizationLevel::MinSize,
273            "max_perf" => cfg.optim_lv = OptimizationLevel::MaxPerformance,
274
275            "no_debug" => cfg.debug = false,
276
277            "vulkan" | "vulkan1_0" => {
278                cfg.env_ty = TargetEnvironmentType::Vulkan;
279                cfg.spv_ver = TargetSpirvVersion::Spirv1_0;
280            },
281            "vulkan1_1" => {
282                cfg.env_ty = TargetEnvironmentType::Vulkan;
283                cfg.spv_ver = TargetSpirvVersion::Spirv1_3;
284            },
285            "vulkan1_2" => {
286                cfg.env_ty = TargetEnvironmentType::Vulkan;
287                cfg.spv_ver = TargetSpirvVersion::Spirv1_5;
288            },
289            "opengl" | "opengl4_5" => {
290                cfg.env_ty = TargetEnvironmentType::OpenGL;
291                cfg.spv_ver = TargetSpirvVersion::Spirv1_0;
292            },
293            "webgpu" => {
294                cfg.env_ty = TargetEnvironmentType::WebGpu;
295                cfg.spv_ver = TargetSpirvVersion::Spirv1_0;
296            }
297
298            "auto_bind" => cfg.auto_bind = true,
299
300            #[cfg(feature = "naga")]
301            "no_y_flip" => cfg.y_flip = false,
302
303            _ => return Err(Error::new(k.span(), "unsupported compilation parameter")),
304        }
305    }
306    Ok(cfg)
307}
308
309fn generate_compile_code(
310    src: &Expr,
311    cfg: &ShaderCompilationConfig,
312) -> Result<proc_macro::TokenStream, String> {
313    use quote::quote;
314    let mut is_valid = false;
315    // This defualt error should not be visible to the users.
316    let mut out = quote!(Err(String::default()));
317    if let Ok(generated_code) = backends::naga::generate_compile_code(src, cfg) {
318        out.extend(quote!(.or_else(#generated_code)));
319        is_valid = true;
320    }
321    if let Ok(generated_code) = backends::shaderc::generate_compile_code(src, cfg) {
322        out.extend(quote!(.or_else(#generated_code)));
323        is_valid = true;
324    }
325    if !is_valid {
326        return Err("cannot find a proper shader compiler backend".to_owned());
327    }
328    Ok(out.into())
329}
330
331impl Parse for JitSpirv {
332    fn parse(mut input: ParseStream) -> ParseResult<Self> {
333        let src = input.parse::<Expr>()?;
334
335        let cfg = parse_compile_cfg(&mut input)?;
336        let tokens = generate_compile_code(&src, &cfg)
337            .map_err(|e| ParseError::new(input.span(), e))?;
338        Ok(JitSpirv(tokens))
339    }
340}
341
342/// Generate shader compilation code to translate GLSL/HLSL/WGSL to SPIR-V
343/// binary word sequence (`Vec<u32>`).
344#[proc_macro]
345pub fn jit_spirv(tokens: TokenStream) -> TokenStream {
346    let JitSpirv(tokens) = parse_macro_input!(tokens as JitSpirv);
347    tokens
348}