extern crate proc_macro;
use std::convert::TryInto;
use std::path::{Path, PathBuf};
mod backends;
use proc_macro::TokenStream;
use quote::quote;
use syn::parse::{Parse, ParseStream, Result as ParseResult, Error as ParseError};
use syn::{parse_macro_input, Ident, LitStr, Token};
/// Shader source language, selected by a language keyword in the macro options.
#[derive(PartialEq, Eq, Clone, Copy)]
enum InputSourceLanguage {
    /// No language keyword was given.
    Unknown,
    /// Selected by `glsl`.
    Glsl,
    /// Selected by `hlsl`.
    Hlsl,
    /// Selected by `wgsl`.
    Wgsl,
    /// Selected by `spvasm`.
    Spvasm,
}
/// SPIR-V version the generated module should target.
#[derive(PartialEq, Eq, Clone, Copy)]
enum TargetSpirvVersion {
    /// SPIR-V 1.0
    Spirv1_0,
    /// SPIR-V 1.1
    Spirv1_1,
    /// SPIR-V 1.2
    Spirv1_2,
    /// SPIR-V 1.3
    Spirv1_3,
    /// SPIR-V 1.4
    Spirv1_4,
    /// SPIR-V 1.5
    Spirv1_5,
    /// SPIR-V 1.6
    Spirv1_6,
}
/// Client API environment the shader is compiled for.
#[derive(PartialEq, Eq, Clone, Copy)]
enum TargetEnvironmentType {
    /// Selected by `vulkan`, `vulkan1_0`, `vulkan1_1` or `vulkan1_2`.
    Vulkan,
    /// Selected by `opengl` or `opengl4_5`.
    OpenGL,
    /// Selected by `webgpu`.
    WebGpu,
}
/// Optimization preference handed to the backends.
#[derive(PartialEq, Eq, Clone, Copy)]
enum OptimizationLevel {
    /// Selected by `min_size`.
    MinSize,
    /// Selected by `max_perf` (also switched on by `hlsl`).
    MaxPerformance,
    /// The default: no optimization level requested.
    None,
}
/// Pipeline stage of the shader, selected by a stage keyword in the macro options.
#[derive(PartialEq, Eq, Clone, Copy)]
enum ShaderKind {
    /// No stage keyword was given.
    Unknown,
    /// `vert`
    Vertex,
    /// `tesc`
    TesselationControl,
    /// `tese`
    TesselationEvaluation,
    /// `geom`
    Geometry,
    /// `frag`
    Fragment,
    /// `comp`
    Compute,
    /// `mesh`
    Mesh,
    /// `task`
    Task,
    /// `rgen`
    RayGeneration,
    /// `rint`
    Intersection,
    /// `rahit`
    AnyHit,
    /// `rchit`
    ClosestHit,
    /// `rmiss`
    Miss,
    /// `rcall`
    Callable,
}
/// Compilation settings collected from the comma-separated options that
/// follow the shader source (or path) in a macro invocation.
struct ShaderCompilationConfig {
    /// Source language (`glsl`, `hlsl`, `wgsl`, `spvasm`).
    lang: InputSourceLanguage,
    /// Include directories, one per `I "dir"` option.
    incl_dirs: Vec<PathBuf>,
    /// Preprocessor definitions from `D NAME` or `D NAME = "value"`.
    defs: Vec<(String, Option<String>)>,
    /// Target SPIR-V version.
    spv_ver: TargetSpirvVersion,
    /// Target client environment.
    env_ty: TargetEnvironmentType,
    /// Entry point name, overridden by `entry = "name"`.
    entry: String,
    /// Requested optimization level.
    optim_lv: OptimizationLevel,
    /// Debug flag, cleared by `no_debug`; its exact effect is backend-specific.
    debug: bool,
    /// Shader stage, from stage keywords such as `vert` or `frag`.
    kind: ShaderKind,
    /// Set by `auto_bind`; how it is applied is up to the backends.
    auto_bind: bool,
    /// Cleared by `no_y_flip`; only present with the `naga` feature.
    #[cfg(feature = "naga")]
    y_flip: bool,
}
impl Default for ShaderCompilationConfig {
    /// Baseline settings before any macro option is applied: unknown language
    /// and stage, Vulkan with SPIR-V 1.0, entry point `main`, no optimization,
    /// debug enabled, `auto_bind` off (and `y_flip` on under `naga`).
    fn default() -> Self {
        Self {
            lang: InputSourceLanguage::Unknown,
            kind: ShaderKind::Unknown,
            incl_dirs: Vec::new(),
            defs: Vec::new(),
            env_ty: TargetEnvironmentType::Vulkan,
            spv_ver: TargetSpirvVersion::Spirv1_0,
            entry: String::from("main"),
            optim_lv: OptimizationLevel::None,
            debug: true,
            auto_bind: false,
            #[cfg(feature = "naga")]
            y_flip: true,
        }
    }
}
/// Result of a successful compilation.
struct CompilationFeedback {
    /// The SPIR-V module as 32-bit words.
    spv: Vec<u32>,
    /// Files the result depends on; each is referenced via `include_bytes!`
    /// in the generated code.
    dep_paths: Vec<String>,
}
/// Parsed and compiled input of `inline_spirv!`.
struct InlineShaderSource(CompilationFeedback);
/// Parsed and compiled (or loaded) input of `include_spirv!`.
struct IncludedShaderSource(CompilationFeedback);
/// Returns the directory named by `CARGO_MANIFEST_DIR`.
///
/// # Panics
/// Panics if `CARGO_MANIFEST_DIR` is not set or is not valid Unicode.
#[inline]
fn get_base_dir() -> PathBuf {
    std::env::var("CARGO_MANIFEST_DIR")
        .map(PathBuf::from)
        .expect("`inline-spirv` can only be used in build time")
}
/// Parses a string literal from `input` and returns its unescaped value.
#[inline]
fn parse_str(input: &mut ParseStream) -> ParseResult<String> {
    let lit: LitStr = input.parse()?;
    Ok(lit.value())
}
/// Parses an identifier from `input` and returns it as a `String`.
#[inline]
fn parse_ident(input: &mut ParseStream) -> ParseResult<String> {
    let ident: Ident = input.parse()?;
    Ok(ident.to_string())
}
/// Parses the comma-separated compilation options that follow the shader
/// source (or path) in a macro invocation.
///
/// Each option is `, <ident>`, optionally followed by arguments:
/// - `I "dir"` adds an include directory;
/// - `D NAME` / `D NAME = "value"` adds a preprocessor definition;
/// - `entry = "name"` overrides the entry point (`entry` alone is ignored).
///
/// Options are applied left to right, so a later option overrides an
/// earlier one. Note that `hlsl` also sets the optimization level to
/// `MaxPerformance`.
///
/// SPIR-V version keywords are accepted as `spirv1_X`. The `spirq1_X`
/// spelling previously accepted here is kept as an alias.
///
/// Parsing stops at end of input, or when a comma is not followed by an
/// identifier (e.g. a trailing comma).
///
/// # Errors
/// Returns a `syn` error if a separator or option argument is malformed,
/// or if an identifier is not a recognized option.
fn parse_compile_cfg(
    input: &mut ParseStream
) -> ParseResult<ShaderCompilationConfig> {
    use syn::Error;
    let mut cfg = ShaderCompilationConfig::default();
    while !input.is_empty() {
        input.parse::<Token![,]>()?;
        let k = match input.parse::<Ident>() {
            Ok(k) => k,
            // Nothing, or a non-identifier, after the comma: stop here.
            // Any tokens left unconsumed are left for the caller to reject.
            Err(_) => break,
        };
        match &k.to_string() as &str {
            // Source language.
            "glsl" => cfg.lang = InputSourceLanguage::Glsl,
            "hlsl" => {
                cfg.lang = InputSourceLanguage::Hlsl;
                cfg.optim_lv = OptimizationLevel::MaxPerformance;
            },
            "wgsl" => cfg.lang = InputSourceLanguage::Wgsl,
            "spvasm" => cfg.lang = InputSourceLanguage::Spvasm,
            // Shader stage.
            "vert" => cfg.kind = ShaderKind::Vertex,
            "tesc" => cfg.kind = ShaderKind::TesselationControl,
            "tese" => cfg.kind = ShaderKind::TesselationEvaluation,
            "geom" => cfg.kind = ShaderKind::Geometry,
            "frag" => cfg.kind = ShaderKind::Fragment,
            "comp" => cfg.kind = ShaderKind::Compute,
            "mesh" => cfg.kind = ShaderKind::Mesh,
            "task" => cfg.kind = ShaderKind::Task,
            "rgen" => cfg.kind = ShaderKind::RayGeneration,
            "rint" => cfg.kind = ShaderKind::Intersection,
            "rahit" => cfg.kind = ShaderKind::AnyHit,
            "rchit" => cfg.kind = ShaderKind::ClosestHit,
            "rmiss" => cfg.kind = ShaderKind::Miss,
            "rcall" => cfg.kind = ShaderKind::Callable,
            // Preprocessing and entry point.
            "I" => {
                cfg.incl_dirs.push(PathBuf::from(parse_str(input)?))
            },
            "D" => {
                let k = parse_ident(input)?;
                let v = if input.parse::<Token![=]>().is_ok() {
                    Some(parse_str(input)?)
                } else {
                    None
                };
                cfg.defs.push((k, v));
            },
            "entry" => {
                if input.parse::<Token![=]>().is_ok() {
                    cfg.entry = parse_str(input)?;
                }
            },
            // Optimization and debug.
            "min_size" => cfg.optim_lv = OptimizationLevel::MinSize,
            "max_perf" => cfg.optim_lv = OptimizationLevel::MaxPerformance,
            "no_debug" => cfg.debug = false,
            // Target environment; each also selects a matching SPIR-V version.
            "vulkan" | "vulkan1_0" => {
                cfg.env_ty = TargetEnvironmentType::Vulkan;
                cfg.spv_ver = TargetSpirvVersion::Spirv1_0;
            },
            "vulkan1_1" => {
                cfg.env_ty = TargetEnvironmentType::Vulkan;
                cfg.spv_ver = TargetSpirvVersion::Spirv1_3;
            },
            "vulkan1_2" => {
                cfg.env_ty = TargetEnvironmentType::Vulkan;
                cfg.spv_ver = TargetSpirvVersion::Spirv1_5;
            },
            "opengl" | "opengl4_5" => {
                cfg.env_ty = TargetEnvironmentType::OpenGL;
                cfg.spv_ver = TargetSpirvVersion::Spirv1_0;
            },
            "webgpu" => {
                cfg.env_ty = TargetEnvironmentType::WebGpu;
                cfg.spv_ver = TargetSpirvVersion::Spirv1_0;
            },
            // Explicit SPIR-V version (`spirq1_X` kept as a compatibility alias).
            "spirv1_0" | "spirq1_0" => cfg.spv_ver = TargetSpirvVersion::Spirv1_0,
            "spirv1_1" | "spirq1_1" => cfg.spv_ver = TargetSpirvVersion::Spirv1_1,
            "spirv1_2" | "spirq1_2" => cfg.spv_ver = TargetSpirvVersion::Spirv1_2,
            "spirv1_3" | "spirq1_3" => cfg.spv_ver = TargetSpirvVersion::Spirv1_3,
            "spirv1_4" | "spirq1_4" => cfg.spv_ver = TargetSpirvVersion::Spirv1_4,
            "spirv1_5" | "spirq1_5" => cfg.spv_ver = TargetSpirvVersion::Spirv1_5,
            "spirv1_6" | "spirq1_6" => cfg.spv_ver = TargetSpirvVersion::Spirv1_6,
            "auto_bind" => cfg.auto_bind = true,
            #[cfg(feature = "naga")]
            "no_y_flip" => cfg.y_flip = false,
            _ => return Err(Error::new(k.span(), "unsupported compilation parameter")),
        }
    }
    Ok(cfg)
}
/// Compiles `src` to SPIR-V by trying each enabled backend in turn.
///
/// Backends are tried in this order: `spirq_spvasm`, then `shaderc` (with
/// the `shaderc` feature), then `naga` (with the `naga` feature). A backend
/// that answers `"unsupported source language"` is skipped. Any other
/// outcome, success or failure, is returned immediately.
///
/// # Errors
/// Returns the first backend error other than "unsupported source language",
/// or `"no supported backend found"` if every backend declined.
fn compile(
    src: &str,
    path: Option<&str>,
    cfg: &ShaderCompilationConfig,
) -> Result<CompilationFeedback, String> {
    /// Message a backend returns to decline a source it cannot handle.
    const DECLINED: &str = "unsupported source language";

    /// `None` if the backend declined, otherwise its final outcome.
    fn verdict(
        res: Result<CompilationFeedback, String>,
    ) -> Option<Result<CompilationFeedback, String>> {
        match res {
            Err(ref e) if e.as_str() == DECLINED => None,
            settled => Some(settled),
        }
    }

    if let Some(res) = verdict(backends::spirq_spvasm::compile(src, path, cfg)) {
        return res;
    }
    #[cfg(feature = "shaderc")]
    {
        if let Some(res) = verdict(backends::shaderc::compile(src, path, cfg)) {
            return res;
        }
    }
    #[cfg(feature = "naga")]
    {
        if let Some(res) = verdict(backends::naga::compile(src, path, cfg)) {
            return res;
        }
    }
    Err("no supported backend found".to_owned())
}
/// Loads a precompiled SPIR-V binary from `path` as 32-bit words.
///
/// Byte order is detected from the SPIR-V magic number `0x07230203` in the
/// first word. Both little- and big-endian files are accepted and converted
/// to native `u32` words.
///
/// Returns `None` in any of these cases:
/// - the file cannot be read;
/// - the file is shorter than one word;
/// - the file length is not a multiple of 4;
/// - the first word is not the SPIR-V magic number in either byte order.
fn build_spirv_binary(path: &Path) -> Option<Vec<u32>> {
    const SPIRV_MAGIC: u32 = 0x0723_0203;

    let buf = std::fs::read(path).ok()?;
    // The length check must happen after reading; a truncated trailing word
    // would otherwise be silently dropped by `chunks_exact`.
    if buf.len() < 4 || buf.len() % 4 != 0 {
        return None;
    }

    let head: [u8; 4] = buf[..4].try_into().ok()?;
    let decode: fn([u8; 4]) -> u32 = if u32::from_le_bytes(head) == SPIRV_MAGIC {
        u32::from_le_bytes
    } else if u32::from_be_bytes(head) == SPIRV_MAGIC {
        u32::from_be_bytes
    } else {
        return None;
    };

    let words = buf
        .chunks_exact(4)
        .map(|chunk| decode(chunk.try_into().expect("chunks_exact(4) yields 4-byte chunks")))
        .collect();
    Some(words)
}
impl Parse for IncludedShaderSource {
    /// Parses `"path" [, options...]` for `include_spirv!`.
    ///
    /// `path` is resolved relative to `CARGO_MANIFEST_DIR`. Files with a
    /// `.spv` extension are loaded as precompiled SPIR-V, and compilation
    /// options are not parsed for them. Any other file is read as text and
    /// compiled with the parsed options.
    ///
    /// The included file itself is always added to `dep_paths`. The
    /// generated code then references it through `include_bytes!`, so rustc
    /// tracks it as a dependency. This applies to `.spv` inputs too.
    ///
    /// # Errors
    /// Returns a `syn` error in any of these cases:
    /// - the path is not an existing file;
    /// - a `.spv` file is not valid SPIR-V;
    /// - a source file cannot be read;
    /// - option parsing fails;
    /// - compilation fails.
    fn parse(mut input: ParseStream) -> ParseResult<Self> {
        use std::ffi::OsStr;
        let path_lit = input.parse::<LitStr>()?;
        let path = get_base_dir().join(path_lit.value());
        // `is_file` implies existence, so a separate `exists` check is redundant.
        if !path.is_file() {
            return Err(ParseError::new(path_lit.span(),
                format!("{path} is not a valid source file", path=path_lit.value())));
        }
        let path_str = path.to_string_lossy().into_owned();

        let is_spirv = path.extension() == Some(OsStr::new("spv"));
        let mut feedback = if is_spirv {
            let spv = build_spirv_binary(&path)
                .ok_or_else(|| ParseError::new(path_lit.span(), "invalid spirv"))?;
            CompilationFeedback {
                spv,
                dep_paths: Vec::new(),
            }
        } else {
            let src = std::fs::read_to_string(&path)
                .map_err(|e| ParseError::new(path_lit.span(), e))?;
            let cfg = parse_compile_cfg(&mut input)?;
            compile(&src, Some(path_str.as_str()), &cfg)
                .map_err(|e| ParseError::new(input.span(), e))?
        };

        // Previously a `.spv` include recorded no dependency, so editing the
        // binary did not trigger re-expansion.
        if !feedback.dep_paths.contains(&path_str) {
            feedback.dep_paths.push(path_str);
        }
        Ok(IncludedShaderSource(feedback))
    }
}
impl Parse for InlineShaderSource {
    /// Parses `"source" [, options...]` for `inline_spirv!` and compiles the
    /// source immediately.
    ///
    /// # Errors
    /// Returns a `syn` error if the literal or options fail to parse, or if
    /// compilation fails. A compilation error is reported at the remaining
    /// input's span.
    fn parse(mut input: ParseStream) -> ParseResult<Self> {
        let source = parse_str(&mut input)?;
        let config = parse_compile_cfg(&mut input)?;
        match compile(&source, None, &config) {
            Ok(feedback) => Ok(InlineShaderSource(feedback)),
            Err(msg) => Err(ParseError::new(input.span(), msg)),
        }
    }
}
fn gen_token_stream(feedback: CompilationFeedback) -> TokenStream {
let CompilationFeedback { spv, dep_paths } = feedback;
(quote! {
{
{ #(let _ = include_bytes!(#dep_paths);)* }
&[#(#spv),*]
}
}).into()
}
/// Compiles a shader given as a string literal (plus optional comma-separated
/// options) at macro-expansion time and expands to its SPIR-V words.
#[proc_macro]
pub fn inline_spirv(tokens: TokenStream) -> TokenStream {
    let parsed = parse_macro_input!(tokens as InlineShaderSource);
    gen_token_stream(parsed.0)
}
/// Loads a `.spv` file, or compiles a shader source file, at a path relative
/// to `CARGO_MANIFEST_DIR`, and expands to its SPIR-V words.
#[proc_macro]
pub fn include_spirv(tokens: TokenStream) -> TokenStream {
    let parsed = parse_macro_input!(tokens as IncludedShaderSource);
    gen_token_stream(parsed.0)
}