//! # inline-spirv
//!
//! The first string is always your shader path or the source code, depending on
//! the macro you use. The parameters that follow give you finer control over
//! the compilation process.
//!
//! ## Source Language
//!
//! `inline-spirv` currently supports four source languages:
//!
//! - `spvasm`: The shader source is in [SPIR-V assembly](https://github.com/KhronosGroup/SPIRV-Tools/blob/main/docs/syntax.md) (always available);
//! - `glsl`: The shader source is in GLSL (enabled by default);
//! - `hlsl`: The shader source is in HLSL (enabled by default);
//! - `wgsl`: The shader source is in WGSL.
//!
//! The experimental WGSL support for WebGPU is available when the `wgsl` feature
//! is enabled, but currently you have to compile with a nightly toolchain. Due to
//! limitations of the `naga` backend, most of the extra parameters have no
//! effect and only the first entry point is emitted in the SPIR-V.
//!
//! ## Shader Stages
//!
//! The following shader stages are supported:
//!
//! - `vert`: Vertex shader;
//! - `tesc`: Tessellation control shader (Hull shader);
//! - `tese`: Tessellation evaluation shader (Domain shader);
//! - `geom`: Geometry shader;
//! - `frag`: Fragment shader (Pixel shader);
//! - `comp`: Compute shader;
//! - `mesh`: (Mesh shading) Mesh shader;
//! - `task`: (Mesh shading) Task shader;
//! - `rgen`: (Raytracing) Ray-generation shader;
//! - `rint`: (Raytracing) Intersection shader;
//! - `rahit`: (Raytracing) Any-hit shader;
//! - `rchit`: (Raytracing) Closest-hit shader;
//! - `rmiss`: (Raytracing) Miss shader;
//! - `rcall`: (Raytracing) Callable shader.
//!
//! ## Specify Entry Function
//!
//! By default the compiler looks for an entry point function named `main`. You
//! can also explicitly specify the entry function name:
//!
//! ```ignore
//! include_spirv!("path/to/shader.hlsl", hlsl, vert, entry="very_main");
//! ```
//!
//! ## Optimization Preference
//!
//! To control how the SPIR-V is optimized:
//!
//! - `min_size`: Optimize for the minimal output size;
//! - `max_perf`: Optimize for the best performance;
//! - `no_debug`: Strip off all the debug information (don't do this if you want
//!   to reflect the SPIR-V and get variable names).
//!
//! ## Include External Source
//!
//! You can use `#include "x.h"` to include a file relative to the shader source
//! file (you cannot use this in inline source), or `#include <x.h>` to include
//! a file relative to any of your provided include directories (searched in
//! order). Each include directory is passed with `I`:
//!
//! ```ignore
//! include_spirv!("path/to/shader.glsl", vert,
//!     I "path/to/shader-headers/",
//!     I "path/to/also-shader-headers/");
//! ```
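//!
//! The lookup rule above can be sketched in plain Rust. This only illustrates
//! the documented behavior under the assumption that a file-system check is all
//! that is needed; `resolve_include` is a hypothetical name, not an API of this
//! crate, and the crate's backends are not claimed to resolve includes this way.
//!
//! ```rust
//! use std::path::{Path, PathBuf};
//!
//! /// Hypothetical helper: resolve an `#include` target following the rule
//! /// above. `quoted` is true for `#include "x.h"`, false for `#include <x.h>`.
//! fn resolve_include(
//!     target: &str,
//!     quoted: bool,
//!     including_file: Option<&Path>,
//!     incl_dirs: &[PathBuf],
//! ) -> Option<PathBuf> {
//!     if quoted {
//!         // `"x.h"` is relative to the directory of the including file.
//!         // Inline sources have no file, so there is nothing to resolve against.
//!         let candidate = including_file?.parent()?.join(target);
//!         return if candidate.is_file() { Some(candidate) } else { None };
//!     }
//!     // `<x.h>` searches the include directories in the order given and
//!     // takes the first match.
//!     incl_dirs
//!         .iter()
//!         .map(|dir| dir.join(target))
//!         .find(|candidate| candidate.is_file())
//! }
//! ```
//!
//! Searching in order means an earlier `I` directory shadows a later one that
//! contains a header with the same name.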
//!
//! ## Include SPIR-V Binary
//!
//! You may also want to inline precompiled SPIR-V binaries if you already have
//! your pipeline set up. To do so, pass a file with the `.spv` extension to
//! `include_spirv!`:
//!
//! ```ignore
//! include_spirv!("path/to/shader.spv");
//! ```
//!
//! Note that all compile arguments are ignored in this case, since there is no
//! compilation.
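//!
//! Embedding a `.spv` file amounts to reinterpreting its bytes as 32-bit
//! words. A minimal standalone sketch of that step uses the SPIR-V magic number
//! `0x07230203` to detect whether the file was written little- or big-endian
//! (`spirv_words` is a hypothetical name, not part of this crate):
//!
//! ```rust
//! /// Hypothetical helper: turn the raw bytes of a `.spv` file into words.
//! fn spirv_words(bytes: &[u8]) -> Option<Vec<u32>> {
//!     const MAGIC: u32 = 0x0723_0203;
//!     // A SPIR-V module is a non-empty sequence of 32-bit words.
//!     if bytes.len() < 4 || bytes.len() % 4 != 0 {
//!         return None;
//!     }
//!     let first = [bytes[0], bytes[1], bytes[2], bytes[3]];
//!     // Pick the byte order under which the first word reads as the magic number.
//!     let decode: fn([u8; 4]) -> u32 = if u32::from_le_bytes(first) == MAGIC {
//!         u32::from_le_bytes
//!     } else if u32::from_be_bytes(first) == MAGIC {
//!         u32::from_be_bytes
//!     } else {
//!         return None;
//!     };
//!     Some(
//!         bytes
//!             .chunks_exact(4)
//!             .map(|c| decode([c[0], c[1], c[2], c[3]]))
//!             .collect(),
//!     )
//! }
//! ```
//!
//! Checking the whole magic word, rather than a single byte, also rejects files
//! that merely happen to start with `0x03` or `0x07`.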
//!
//! ## Compiler Definition
//!
//! You can also define preprocessor macros with `D`, optionally giving each a
//! substitution value:
//!
//! ```ignore
//! include_spirv!("path/to/shader.glsl", vert,
//!     D USE_LIGHTMAP,
//!     D LIGHTMAP_COUNT="2");
//! ```
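//!
//! Conceptually, `D USE_LIGHTMAP` corresponds to `#define USE_LIGHTMAP` and
//! `D LIGHTMAP_COUNT="2"` to `#define LIGHTMAP_COUNT 2`. The sketch below only
//! illustrates that correspondence; `define_lines` is a hypothetical helper and
//! is not claimed to be how the backends pass definitions to the compiler.
//!
//! ```rust
//! /// Hypothetical helper: render `D` parameters as `#define` lines. Each entry
//! /// mirrors a `D NAME` (no value) or `D NAME="value"` parameter.
//! fn define_lines(defs: &[(String, Option<String>)]) -> String {
//!     defs.iter()
//!         .map(|(name, value)| match value {
//!             Some(v) => format!("#define {} {}\n", name, v),
//!             None => format!("#define {}\n", name),
//!         })
//!         .collect()
//! }
//! ```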
//!
//! Of course, once you start relying on macro definitions, your shaders are
//! probably getting so dynamic that this little crate might not be enough. Then
//! it might be a good time to build your own shader compilation pipeline!
//!
//! ## Target Environment
//!
//! You can request a specific version of the target environment:
//!
//! - `vulkan1_0` (or `vulkan`) for Vulkan 1.0 (default, supports SPIR-V 1.0);
//! - `vulkan1_1` for Vulkan 1.1 (supports SPIR-V 1.3);
//! - `vulkan1_2` for Vulkan 1.2 (supports SPIR-V 1.5);
//! - `opengl4_5` (or `opengl`) for the OpenGL 4.5 core profile;
//! - `webgpu` for WebGPU.
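//!
//! Each target-environment flag also fixes the SPIR-V version to emit; the
//! parser selects SPIR-V 1.0 for `opengl4_5` and `webgpu`, and also accepts
//! `vulkan` and `opengl` as aliases. A lookup table with exactly those pairs
//! (`spirv_version` is a hypothetical name, returning `(major, minor)`):
//!
//! ```rust
//! /// Hypothetical helper: the SPIR-V version implied by a target flag.
//! fn spirv_version(flag: &str) -> Option<(u8, u8)> {
//!     match flag {
//!         "vulkan" | "vulkan1_0" => Some((1, 0)),
//!         "vulkan1_1" => Some((1, 3)),
//!         "vulkan1_2" => Some((1, 5)),
//!         "opengl" | "opengl4_5" | "webgpu" => Some((1, 0)),
//!         // Anything else is not a target-environment flag.
//!         _ => None,
//!     }
//! }
//! ```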
//!
//! ## Descriptor Auto-binding
//!
//! If you have just come off work and are tooooo tired to specify the
//! descriptor binding points yourself, you can switch on `auto_bind`:
//!
//! ```ignore
//! inline_spirv!(r#"
//!     #version 450 core
//!     uniform sampler2D lightmap;
//!     uniform sampler2D emit_map;
//!     void main() {}
//! "#, glsl, frag, auto_bind);
//! ```
//!
//! However, this is not recommended unless you have an automated reflection
//! tool to retrieve the actual binding points.
//!
//! ## Flip-Y for WebGPU
//!
//! If you intend to compile WGSL for a WebGPU backend, note that the `naga`
//! backend inverts the Y-axis by default, due to the discrepancy in NDC
//! (Normalized Device Coordinates) between WebGPU and Vulkan. If such a
//! correction is undesired, you can opt out with `no_y_flip` (this parameter is
//! only recognized when the `naga` backend is enabled).
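//!
//! To make the convention concrete: WebGPU's NDC has +Y pointing up while
//! Vulkan's has +Y pointing down, so the correction amounts to negating the Y
//! component of the clip-space position a vertex shader outputs. A minimal
//! sketch of that arithmetic; `flip_y` is a hypothetical helper and says
//! nothing about how `naga` rewrites the shader internally:
//!
//! ```rust
//! /// Hypothetical helper: convert a clip-space position between the WebGPU
//! /// (+Y up) and Vulkan (+Y down) conventions by negating Y.
//! fn flip_y(position: [f32; 4]) -> [f32; 4] {
//!     let [x, y, z, w] = position;
//!     [x, -y, z, w]
//! }
//! ```
//!
//! Since NDC is clip space divided by `w`, negating the clip-space Y negates the
//! NDC Y as well, and applying the flip twice restores the original position.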
//!
//! ## Tips
//!
//! The macro invocation can get verbose, especially when you have a bunch of
//! `#include`s, so keep in mind that you can alias it and define a more
//! customized macro for yourself:
//!
//! ```ignore
//! use inline_spirv::include_spirv as include_spirv_raw;
//!
//! macro_rules! include_spirv {
//!     ($path:expr, $stage:ident) => {
//!         include_spirv_raw!(
//!             $path,
//!             $stage, hlsl,
//!             entry="my_entry_pt",
//!             D VERBOSE_DEFINITION,
//!             D ANOTHER_VERBOSE_DEFINITION="verbose definition substitution",
//!             I "long/path/to/include/directory",
//!         )
//!     }
//! }
//!
//! // ...
//! let vert: &[u32] = include_spirv!("examples/demo/assets/demo.hlsl", vert);
//! ```
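//!
//! Both macros evaluate to a `&'static [u32]`. If an API wants the shader as
//! bytes instead, one option is to serialize each word in little-endian order;
//! the leading magic number `0x07230203` then still tells consumers the byte
//! order. A sketch, with `spirv_bytes` as a hypothetical helper name:
//!
//! ```rust
//! /// Hypothetical helper: serialize SPIR-V words to little-endian bytes.
//! fn spirv_bytes(words: &[u32]) -> Vec<u8> {
//!     words.iter().flat_map(|w| w.to_le_bytes()).collect()
//! }
//! ```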
extern crate proc_macro;
use std::convert::TryInto;
use std::path::{Path, PathBuf};

mod backends;

use proc_macro::TokenStream;
use quote::quote;
use syn::parse::{Parse, ParseStream, Result as ParseResult, Error as ParseError};
use syn::{parse_macro_input, Ident, LitStr, Token};

#[derive(Clone, Copy, PartialEq, Eq)]
enum InputSourceLanguage {
    Unknown,
    Glsl,
    Hlsl,
    Wgsl,
    Spvasm,
}
#[derive(Clone, Copy, PartialEq, Eq)]
enum TargetSpirvVersion {
    Spirv1_0,
    Spirv1_1,
    Spirv1_2,
    Spirv1_3,
    Spirv1_4,
    Spirv1_5,
    Spirv1_6,
}
#[derive(Clone, Copy, PartialEq, Eq)]
enum TargetEnvironmentType {
    Vulkan,
    OpenGL,
    WebGpu,
}
#[derive(Clone, Copy, PartialEq, Eq)]
enum OptimizationLevel {
    MinSize,
    MaxPerformance,
    None,
}
#[derive(Clone, Copy, PartialEq, Eq)]
enum ShaderKind {
    Unknown,

    Vertex,
    TesselationControl,
    TesselationEvaluation,
    Geometry,
    Fragment,
    Compute,
    // Mesh Pipeline
    Mesh,
    Task,
    // Ray-tracing Pipeline
    RayGeneration,
    Intersection,
    AnyHit,
    ClosestHit,
    Miss,
    Callable,
}

struct ShaderCompilationConfig {
    lang: InputSourceLanguage,
    incl_dirs: Vec<PathBuf>,
    defs: Vec<(String, Option<String>)>,
    spv_ver: TargetSpirvVersion,
    env_ty: TargetEnvironmentType,
    entry: String,
    optim_lv: OptimizationLevel,
    debug: bool,
    kind: ShaderKind,
    auto_bind: bool,
    // Backend specific.
    #[cfg(feature = "naga")]
    y_flip: bool,
}
impl Default for ShaderCompilationConfig {
    fn default() -> Self {
        ShaderCompilationConfig {
            lang: InputSourceLanguage::Unknown,
            incl_dirs: Vec::new(),
            defs: Vec::new(),
            spv_ver: TargetSpirvVersion::Spirv1_0,
            env_ty: TargetEnvironmentType::Vulkan,
            entry: "main".to_owned(),
            optim_lv: OptimizationLevel::None,
            debug: true,
            kind: ShaderKind::Unknown,
            auto_bind: false,

            #[cfg(feature = "naga")]
            y_flip: true,
        }
    }
}

struct CompilationFeedback {
    spv: Vec<u32>,
    dep_paths: Vec<String>,
}
struct InlineShaderSource(CompilationFeedback);
struct IncludedShaderSource(CompilationFeedback);

#[inline]
fn get_base_dir() -> PathBuf {
    let base_dir = std::env::var("CARGO_MANIFEST_DIR")
        .expect("`inline-spirv` can only be used at build time");
    PathBuf::from(base_dir)
}
#[inline]
fn parse_str(input: &mut ParseStream) -> ParseResult<String> {
    input.parse::<LitStr>()
        .map(|x| x.value())
}
#[inline]
fn parse_ident(input: &mut ParseStream) -> ParseResult<String> {
    input.parse::<Ident>()
        .map(|x| x.to_string())
}

fn parse_compile_cfg(
    input: &mut ParseStream
) -> ParseResult<ShaderCompilationConfig> {
    let mut cfg = ShaderCompilationConfig::default();
    while !input.is_empty() {
        use syn::Error;
        // Parameters are separated by commas for readability. If no
        // identifier follows a comma (e.g. a trailing comma), stop parsing.
        input.parse::<Token![,]>()?;
        let k = if let Ok(k) = input.parse::<Ident>() { k } else { break };
        match &k.to_string() as &str {
            "glsl" => cfg.lang = InputSourceLanguage::Glsl,
            "hlsl" => {
                cfg.lang = InputSourceLanguage::Hlsl;
                // `glslangValidator` warns that HLSL output might be illegal
                // with optimization disabled, so enable it by default.
                cfg.optim_lv = OptimizationLevel::MaxPerformance;
            },
            "wgsl" => cfg.lang = InputSourceLanguage::Wgsl,
            "spvasm" => cfg.lang = InputSourceLanguage::Spvasm,

            "vert" => cfg.kind = ShaderKind::Vertex,
            "tesc" => cfg.kind = ShaderKind::TesselationControl,
            "tese" => cfg.kind = ShaderKind::TesselationEvaluation,
            "geom" => cfg.kind = ShaderKind::Geometry,
            "frag" => cfg.kind = ShaderKind::Fragment,
            "comp" => cfg.kind = ShaderKind::Compute,
            "mesh" => cfg.kind = ShaderKind::Mesh,
            "task" => cfg.kind = ShaderKind::Task,
            "rgen" => cfg.kind = ShaderKind::RayGeneration,
            "rint" => cfg.kind = ShaderKind::Intersection,
            "rahit" => cfg.kind = ShaderKind::AnyHit,
            "rchit" => cfg.kind = ShaderKind::ClosestHit,
            "rmiss" => cfg.kind = ShaderKind::Miss,
            "rcall" => cfg.kind = ShaderKind::Callable,

            "I" => {
                cfg.incl_dirs.push(PathBuf::from(parse_str(input)?))
            },
            "D" => {
                let k = parse_ident(input)?;
                let v = if input.parse::<Token![=]>().is_ok() {
                    Some(parse_str(input)?)
                } else { None };
                cfg.defs.push((k, v));
            },

            "entry" => {
                if input.parse::<Token![=]>().is_ok() {
                    cfg.entry = parse_str(input)?;
                }
            }

            "min_size" => cfg.optim_lv = OptimizationLevel::MinSize,
            "max_perf" => cfg.optim_lv = OptimizationLevel::MaxPerformance,

            "no_debug" => cfg.debug = false,

            "vulkan" | "vulkan1_0" => {
                cfg.env_ty = TargetEnvironmentType::Vulkan;
                cfg.spv_ver = TargetSpirvVersion::Spirv1_0;
            },
            "vulkan1_1" => {
                cfg.env_ty = TargetEnvironmentType::Vulkan;
                cfg.spv_ver = TargetSpirvVersion::Spirv1_3;
            },
            "vulkan1_2" => {
                cfg.env_ty = TargetEnvironmentType::Vulkan;
                cfg.spv_ver = TargetSpirvVersion::Spirv1_5;
            },
            "opengl" | "opengl4_5" => {
                cfg.env_ty = TargetEnvironmentType::OpenGL;
                cfg.spv_ver = TargetSpirvVersion::Spirv1_0;
            },
            "webgpu" => {
                cfg.env_ty = TargetEnvironmentType::WebGpu;
                cfg.spv_ver = TargetSpirvVersion::Spirv1_0;
            }

            "spirq1_0" => cfg.spv_ver = TargetSpirvVersion::Spirv1_0,
            "spirq1_1" => cfg.spv_ver = TargetSpirvVersion::Spirv1_1,
            "spirq1_2" => cfg.spv_ver = TargetSpirvVersion::Spirv1_2,
            "spirq1_3" => cfg.spv_ver = TargetSpirvVersion::Spirv1_3,
            "spirq1_4" => cfg.spv_ver = TargetSpirvVersion::Spirv1_4,
            "spirq1_5" => cfg.spv_ver = TargetSpirvVersion::Spirv1_5,
            "spirq1_6" => cfg.spv_ver = TargetSpirvVersion::Spirv1_6,

            "auto_bind" => cfg.auto_bind = true,

            #[cfg(feature = "naga")]
            "no_y_flip" => cfg.y_flip = false,

            _ => return Err(Error::new(k.span(), "unsupported compilation parameter")),
        }
    }
    Ok(cfg)
}

fn compile(
    src: &str,
    path: Option<&str>,
    cfg: &ShaderCompilationConfig,
) -> Result<CompilationFeedback, String> {
    // Try each enabled backend in turn. A backend that cannot handle the
    // requested source language reports "unsupported source language", and
    // the next backend gets a chance; any other error is returned as-is.
    match backends::spirq_spvasm::compile(src, path, cfg) {
        Ok(x) => return Ok(x),
        Err(e) if e != "unsupported source language" => return Err(e),
        _ => {}
    }
    #[cfg(feature = "shaderc")]
    match backends::shaderc::compile(src, path, cfg) {
        Ok(x) => return Ok(x),
        Err(e) if e != "unsupported source language" => return Err(e),
        _ => {}
    }
    #[cfg(feature = "naga")]
    match backends::naga::compile(src, path, cfg) {
        Ok(x) => return Ok(x),
        Err(e) if e != "unsupported source language" => return Err(e),
        _ => {}
    }
    Err("no supported backend found".to_owned())
}

fn build_spirv_binary(path: &Path) -> Option<Vec<u32>> {
    // Any I/O error means there is no binary to embed.
    let buf = std::fs::read(path).ok()?;
    // SPIR-V is a sequence of 32-bit words, so the file must be non-empty and
    // its length a multiple of 4. (Checked after reading, not before.)
    if buf.is_empty() || buf.len() % 4 != 0 {
        return None;
    }

    // The first word is the magic number 0x07230203; its byte layout tells us
    // the byte order the module was written in.
    let decode: fn([u8; 4]) -> u32 = match &buf[..4] {
        [0x03, 0x02, 0x23, 0x07] => u32::from_le_bytes,
        [0x07, 0x23, 0x02, 0x03] => u32::from_be_bytes,
        _ => return None,
    };
    let out = buf.chunks_exact(4)
        .map(|x| decode(x.try_into().unwrap()))
        .collect::<Vec<u32>>();

    Some(out)
}

impl Parse for IncludedShaderSource {
    fn parse(mut input: ParseStream) -> ParseResult<Self> {
        use std::ffi::OsStr;
        let path_lit = input.parse::<LitStr>()?;
        let path = get_base_dir().join(&path_lit.value());

        if !path.is_file() {
            return Err(ParseError::new(path_lit.span(),
                format!("{path} is not a valid source file", path=path_lit.value())));
        }

        let is_spirv = path.extension() == Some(OsStr::new("spv"));
        let feedback = if is_spirv {
            // Precompiled binaries need no compilation, but any trailing
            // parameters must still be consumed, otherwise `syn` rejects them
            // as unexpected tokens. Their values are ignored.
            parse_compile_cfg(&mut input)?;
            let spv = build_spirv_binary(&path)
                .ok_or_else(|| syn::Error::new(path_lit.span(), "invalid spirv"))?;
            CompilationFeedback {
                spv,
                // Track the binary itself so that editing it triggers a rebuild.
                dep_paths: vec![path.to_string_lossy().into_owned()],
            }
        } else {
            let src = std::fs::read_to_string(&path)
                .map_err(|e| syn::Error::new(path_lit.span(), e))?;
            let cfg = parse_compile_cfg(&mut input)?;
            compile(&src, Some(path.to_string_lossy().as_ref()), &cfg)
                .map_err(|e| ParseError::new(input.span(), e))?
        };
        let rv = IncludedShaderSource(feedback);
        Ok(rv)
    }
}
impl Parse for InlineShaderSource {
    fn parse(mut input: ParseStream) -> ParseResult<Self> {
        let src = parse_str(&mut input)?;
        let cfg = parse_compile_cfg(&mut input)?;
        let feedback = compile(&src, None, &cfg)
            .map_err(|e| ParseError::new(input.span(), e))?;
        let rv = InlineShaderSource(feedback);
        Ok(rv)
    }
}

fn gen_token_stream(feedback: CompilationFeedback) -> TokenStream {
    let CompilationFeedback { spv, dep_paths } = feedback;
    // Each dependency is referenced with `include_bytes!` so that the compiler
    // tracks it and rebuilds the crate when it changes; the bytes themselves
    // are discarded. The expansion evaluates to a `&'static [u32]`.
    (quote! {
        {
            { #(let _ = include_bytes!(#dep_paths);)* }
            &[#(#spv),*]
        }
    }).into()
}

/// Compile inline shader source and embed the SPIR-V binary word sequence.
/// Returns a `&'static [u32]`.
#[proc_macro]
pub fn inline_spirv(tokens: TokenStream) -> TokenStream {
    let InlineShaderSource(feedback) = parse_macro_input!(tokens as InlineShaderSource);
    gen_token_stream(feedback)
}

/// Compile an external shader source file (or load a precompiled `.spv`
/// binary) and embed the SPIR-V binary word sequence.
/// Returns a `&'static [u32]`.
#[proc_macro]
pub fn include_spirv(tokens: TokenStream) -> TokenStream {
    let IncludedShaderSource(feedback) = parse_macro_input!(tokens as IncludedShaderSource);
    gen_token_stream(feedback)
}
484    let IncludedShaderSource(feedback) = parse_macro_input!(tokens as IncludedShaderSource);
485    gen_token_stream(feedback)
486}