extern crate proc_macro;
mod backends;
#[cfg(not(any(feature = "shaderc", feature = "naga")))]
compile_error!("no compiler backend enabled; please specify at least one of \
the following input source features: `glsl`, `hlsl`, `wgsl`");
use proc_macro::TokenStream;
use syn::parse::{Parse, ParseStream, Result as ParseResult, Error as ParseError};
use syn::{parse_macro_input, Ident, LitStr, Token, Expr};
/// Language of the shader source handed to `jit_spirv!`.
///
/// Chosen by the `glsl`, `hlsl` or `wgsl` parameter; stays `Unknown` when
/// none of them is given.
#[derive(Copy, Clone)]
enum InputSourceLanguage {
    /// No language parameter was supplied.
    Unknown,
    /// Selected by `glsl`.
    Glsl,
    /// Selected by `hlsl`.
    Hlsl,
    /// Selected by `wgsl`.
    Wgsl,
}
/// SPIR-V version the generated module targets.
///
/// Set together with the target environment: `vulkan`/`vulkan1_0`,
/// `opengl`/`opengl4_5` and `webgpu` pick 1.0, `vulkan1_1` picks 1.3 and
/// `vulkan1_2` picks 1.5.
#[derive(Copy, Clone)]
enum TargetSpirvVersion {
    Spirv1_0,
    // No parameter keyword selects this version yet, hence the allowance.
    #[allow(dead_code)]
    Spirv1_1,
    // No parameter keyword selects this version yet, hence the allowance.
    #[allow(dead_code)]
    Spirv1_2,
    Spirv1_3,
    // No parameter keyword selects this version yet, hence the allowance.
    #[allow(dead_code)]
    Spirv1_4,
    Spirv1_5,
}
/// Client environment the SPIR-V is produced for.
///
/// `Vulkan` is the default. It is changed by the `vulkan*`, `opengl*` and
/// `webgpu` parameters.
#[derive(Copy, Clone)]
enum TargetEnvironmentType {
    /// Selected by `vulkan`, `vulkan1_0`, `vulkan1_1` or `vulkan1_2`.
    Vulkan,
    /// Selected by `opengl` or `opengl4_5`.
    OpenGL,
    /// Selected by `webgpu`.
    WebGpu,
}
/// Optimization level requested from the compiler backend.
///
/// Defaults to `None`. The `min_size` and `max_perf` parameters select the
/// other two levels.
#[derive(Copy, Clone)]
enum OptimizationLevel {
    /// Selected by `min_size`.
    MinSize,
    /// Selected by `max_perf`.
    MaxPerformance,
    /// No optimization requested (the default).
    None,
}
/// Shader stage, chosen by one of the stage keywords; `Unknown` when no
/// stage keyword is given.
#[derive(Copy, Clone)]
enum ShaderKind {
    /// No stage keyword was supplied.
    Unknown,
    /// `vert`
    Vertex,
    /// `tesc`
    TesselationControl,
    /// `tese`
    TesselationEvaluation,
    /// `geom`
    Geometry,
    /// `frag`
    Fragment,
    /// `comp`
    Compute,
    /// `mesh`
    Mesh,
    /// `task`
    Task,
    /// `rgen`
    RayGeneration,
    /// `rint`
    Intersection,
    /// `rahit`
    AnyHit,
    /// `rchit`
    ClosestHit,
    /// `rmiss`
    Miss,
    /// `rcall`
    Callable,
}
/// Every option collected from the `jit_spirv!` parameter list; passed to
/// the backends' code generators.
struct ShaderCompilationConfig {
    /// Source path given with the `path` parameter, if any.
    path: Option<String>,
    /// Source language (`glsl` / `hlsl` / `wgsl`).
    lang: InputSourceLanguage,
    /// Directories added with `I "dir"`, in the order given.
    incl_dirs: Vec<String>,
    /// Definitions from `D NAME` (value `None`) or `D NAME = "value"`.
    defs: Vec<(String, Option<String>)>,
    /// SPIR-V version, set together with the target environment.
    spv_ver: TargetSpirvVersion,
    /// Target environment (`vulkan*`, `opengl*`, `webgpu`).
    env_ty: TargetEnvironmentType,
    /// Entry-point name, replaced by `entry = "name"`.
    entry: String,
    /// Optimization level (`min_size` / `max_perf`).
    optim_lv: OptimizationLevel,
    /// Cleared by the `no_debug` parameter.
    debug: bool,
    /// Shader stage (`vert`, `frag`, `comp`, ...).
    kind: ShaderKind,
    /// Set by the `auto_bind` parameter.
    auto_bind: bool,
    /// Only present with the `naga` feature; cleared by `no_y_flip`.
    #[cfg(feature = "naga")]
    y_flip: bool,
}
impl Default for ShaderCompilationConfig {
    /// The configuration before any macro parameter is applied.
    ///
    /// Language and stage are unknown, and no path, include directories or
    /// definitions are set. The target is Vulkan with SPIR-V 1.0, the entry
    /// point is `"main"`, and no optimization is requested. `debug` is on,
    /// `auto_bind` is off, and (with `naga`) `y_flip` is on.
    fn default() -> Self {
        Self {
            lang: InputSourceLanguage::Unknown,
            kind: ShaderKind::Unknown,
            path: None,
            incl_dirs: Vec::new(),
            defs: Vec::new(),
            env_ty: TargetEnvironmentType::Vulkan,
            spv_ver: TargetSpirvVersion::Spirv1_0,
            entry: String::from("main"),
            optim_lv: OptimizationLevel::None,
            debug: true,
            auto_bind: false,
            #[cfg(feature = "naga")]
            y_flip: true,
        }
    }
}
/// Result of parsing the `jit_spirv!` arguments: the token stream the macro
/// expands to.
struct JitSpirv(TokenStream);
/// Consumes one string literal from `input` and returns its unescaped value.
///
/// # Errors
///
/// Fails if the next token is not a string literal.
#[inline]
fn parse_str(input: &mut ParseStream) -> ParseResult<String> {
    let literal: LitStr = input.parse()?;
    Ok(literal.value())
}
/// Consumes one identifier from `input` and returns it as a `String`.
///
/// # Errors
///
/// Fails if the next token is not an identifier.
#[inline]
fn parse_ident(input: &mut ParseStream) -> ParseResult<String> {
    let ident: Ident = input.parse()?;
    Ok(ident.to_string())
}
/// Parses the compilation parameters that follow the source expression in
/// `jit_spirv!(src, param, param, ...)`.
///
/// Each parameter is introduced by a comma and an identifier. Most
/// identifiers are flags; a few take an argument:
///
/// * `path = "file"` (or the original `path, "file"` form) sets `cfg.path`;
/// * `I "dir"` appends an include directory;
/// * `D NAME` or `D NAME = "value"` appends a definition;
/// * `entry = "name"` replaces the entry-point name (an `entry` without `=`
///   is ignored).
///
/// For single-valued settings the last occurrence wins; `I` and `D`
/// accumulate. HLSL input gets `OptimizationLevel::MaxPerformance` unless
/// `min_size` or `max_perf` is given, and this holds whatever order the
/// parameters appear in.
///
/// A comma that is not followed by an identifier (e.g. a trailing comma)
/// ends the list; any remaining tokens are left in `input`.
///
/// # Errors
///
/// Fails if a parameter is not preceded by a comma, if an argument has the
/// wrong token type, or if an identifier is not a supported parameter.
fn parse_compile_cfg(
    input: &mut ParseStream
) -> ParseResult<ShaderCompilationConfig> {
    use syn::Error;
    let mut cfg = ShaderCompilationConfig::default();
    // Level chosen explicitly with `min_size` / `max_perf`. It is kept apart
    // from `cfg.optim_lv` so the HLSL default applied after the loop never
    // overrides an explicit choice, e.g. `min_size, hlsl`.
    let mut explicit_optim_lv = None;
    while !input.is_empty() {
        // Separates parameters (and the first one from the source expression).
        input.parse::<Token![,]>()?;
        let k = match input.parse::<Ident>() {
            Ok(k) => k,
            // A comma without a following identifier ends the list.
            Err(_) => break,
        };
        match k.to_string().as_str() {
            "path" => {
                // `path = "..."` matches the `D` / `entry` syntax; the
                // original `path, "..."` form is still accepted.
                if input.parse::<Token![=]>().is_err() {
                    input.parse::<Token![,]>()?;
                }
                cfg.path = Some(parse_str(input)?);
            },
            "glsl" => cfg.lang = InputSourceLanguage::Glsl,
            "hlsl" => cfg.lang = InputSourceLanguage::Hlsl,
            "wgsl" => cfg.lang = InputSourceLanguage::Wgsl,
            "vert" => cfg.kind = ShaderKind::Vertex,
            "tesc" => cfg.kind = ShaderKind::TesselationControl,
            "tese" => cfg.kind = ShaderKind::TesselationEvaluation,
            "geom" => cfg.kind = ShaderKind::Geometry,
            "frag" => cfg.kind = ShaderKind::Fragment,
            "comp" => cfg.kind = ShaderKind::Compute,
            "mesh" => cfg.kind = ShaderKind::Mesh,
            "task" => cfg.kind = ShaderKind::Task,
            "rgen" => cfg.kind = ShaderKind::RayGeneration,
            "rint" => cfg.kind = ShaderKind::Intersection,
            "rahit" => cfg.kind = ShaderKind::AnyHit,
            "rchit" => cfg.kind = ShaderKind::ClosestHit,
            "rmiss" => cfg.kind = ShaderKind::Miss,
            "rcall" => cfg.kind = ShaderKind::Callable,
            "I" => cfg.incl_dirs.push(parse_str(input)?),
            "D" => {
                let name = parse_ident(input)?;
                let value = if input.parse::<Token![=]>().is_ok() {
                    Some(parse_str(input)?)
                } else {
                    None
                };
                cfg.defs.push((name, value));
            },
            "entry" => {
                if input.parse::<Token![=]>().is_ok() {
                    cfg.entry = parse_str(input)?;
                }
            },
            "min_size" => explicit_optim_lv = Some(OptimizationLevel::MinSize),
            "max_perf" => explicit_optim_lv = Some(OptimizationLevel::MaxPerformance),
            "no_debug" => cfg.debug = false,
            "vulkan" | "vulkan1_0" => {
                cfg.env_ty = TargetEnvironmentType::Vulkan;
                cfg.spv_ver = TargetSpirvVersion::Spirv1_0;
            },
            "vulkan1_1" => {
                cfg.env_ty = TargetEnvironmentType::Vulkan;
                cfg.spv_ver = TargetSpirvVersion::Spirv1_3;
            },
            "vulkan1_2" => {
                cfg.env_ty = TargetEnvironmentType::Vulkan;
                cfg.spv_ver = TargetSpirvVersion::Spirv1_5;
            },
            "opengl" | "opengl4_5" => {
                cfg.env_ty = TargetEnvironmentType::OpenGL;
                cfg.spv_ver = TargetSpirvVersion::Spirv1_0;
            },
            "webgpu" => {
                cfg.env_ty = TargetEnvironmentType::WebGpu;
                cfg.spv_ver = TargetSpirvVersion::Spirv1_0;
            },
            "auto_bind" => cfg.auto_bind = true,
            #[cfg(feature = "naga")]
            "no_y_flip" => cfg.y_flip = false,
            _ => return Err(Error::new(k.span(), "unsupported compilation parameter")),
        }
    }
    // Resolve the optimization level once all parameters are known, so the
    // result does not depend on where `hlsl` appears relative to `min_size`
    // or `max_perf`.
    cfg.optim_lv = match (explicit_optim_lv, cfg.lang) {
        (Some(lv), _) => lv,
        // HLSL input is compiled with optimizations on unless told otherwise.
        (None, InputSourceLanguage::Hlsl) => OptimizationLevel::MaxPerformance,
        (None, _) => cfg.optim_lv,
    };
    Ok(cfg)
}
/// Generates the expression `jit_spirv!` expands to for source `src` under
/// configuration `cfg`.
///
/// The emitted expression starts as `Err(String::default())`. One
/// `.or_else(...)` call is appended for each backend whose code generation
/// succeeded, naga first and shaderc second. At runtime the chain therefore
/// stops at the first backend closure that yields `Ok`.
///
/// # Errors
///
/// Returns a message when neither backend could generate code for `cfg`.
fn generate_compile_code(
    src: &Expr,
    cfg: &ShaderCompilationConfig,
) -> Result<proc_macro::TokenStream, String> {
    use quote::quote;
    // Query backends in a fixed order; the chain below preserves it.
    let naga_code = backends::naga::generate_compile_code(src, cfg);
    let shaderc_code = backends::shaderc::generate_compile_code(src, cfg);

    let mut chain = quote!(Err(String::default()));
    let mut backends_used = 0usize;
    if let Ok(code) = naga_code {
        chain.extend(quote!(.or_else(#code)));
        backends_used += 1;
    }
    if let Ok(code) = shaderc_code {
        chain.extend(quote!(.or_else(#code)));
        backends_used += 1;
    }

    if backends_used == 0 {
        return Err("cannot find a proper shader compiler backend".to_owned());
    }
    Ok(chain.into())
}
impl Parse for JitSpirv {
    /// Parses `src, params...`: a Rust expression for the shader source,
    /// then the compilation parameters. Any code-generation failure is
    /// reported as a parse error at the current input position.
    fn parse(mut input: ParseStream) -> ParseResult<Self> {
        let source_expr: Expr = input.parse()?;
        let config = parse_compile_cfg(&mut input)?;
        match generate_compile_code(&source_expr, &config) {
            Ok(expansion) => Ok(JitSpirv(expansion)),
            Err(message) => Err(ParseError::new(input.span(), message)),
        }
    }
}
#[proc_macro]
pub fn jit_spirv(tokens: TokenStream) -> TokenStream {
let JitSpirv(tokens) = parse_macro_input!(tokens as JitSpirv);
tokens
}