use serde::Deserialize;
use std::cell::RefCell;
use syn::parse::{Parse, ParseStream};
use syn::LitStr;
use syn::Token;
mod args;
mod shaderc_serde;
pub fn compile_from<S: Source + Parse>(
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let mut tokens = input.into_iter();
let first = proc_macro::TokenStream::from(match tokens.next() {
Some(f) => f,
None => {
return syn::Error::new(proc_macro2::Span::call_site(), "expected string literal")
.to_compile_error()
.into()
}
});
let source: S = syn::parse_macro_input!(first as S);
let input = if let Some(tok) = tokens.next() {
let is_comma = if let proc_macro::TokenTree::Punct(p) = tok.clone() {
p.as_char() == ','
} else {
false
};
let input = tokens.collect::<proc_macro::TokenStream>();
if is_comma || input.is_empty() {
input
} else {
return syn::Error::new(
proc_macro2::Span::from(tok.span()),
format!("unexpected {}", tok.to_string()),
)
.to_compile_error()
.into();
}
} else {
tokens.collect::<proc_macro::TokenStream>()
};
let args: args::Args =
match serde_tokenstream::from_tokenstream(&proc_macro2::TokenStream::from(input)) {
Ok(args) => args,
Err(e) => {
return proc_macro::TokenStream::from(
syn::Error::new(proc_macro2::Span::call_site(), e).to_compile_error(),
)
}
};
match compile(source, args).map_err(|e| syn::Error::new(proc_macro2::Span::call_site(), e)) {
Ok(CompileResult {
warnings,
includes,
data,
}) => if let Some(warn) = warnings {
emit_call_site_warning!(warn);
quote::quote! {
{
#(const _INCLUDE: &[u8] = include_bytes!(#includes);)*
&[#(#data),*]
}
}
} else {
quote::quote! {
{
#(const _INCLUDE: &[u8] = include_bytes!(#includes);)*
&[#(#data),*]
}
}
}
.into(),
Err(err) => proc_macro::TokenStream::from(err.to_compile_error()),
}
}
/// Output of a successful call to `compile`.
pub struct CompileResult {
/// Warning messages reported by the compiler; `None` when the compiler
/// reported zero warnings.
pub warnings: Option<String>,
/// Paths recorded during compilation: the source's own path (when it has
/// one) followed by every include the include callback tried to resolve.
pub includes: Vec<String>,
/// The compiled binary as a sequence of `u32` words.
pub data: Vec<u32>,
}
/// Compiles `source` into SPIR-V with shaderc, configured by `args`.
///
/// The shader kind is taken from `args.kind`, else guessed from the source
/// path's extension, else left for shaderc to infer from the source. The
/// input file name defaults to `source.name()` and the entry point to
/// `"main"` when `args` does not provide them.
///
/// `#include` directives are resolved by a callback: relative includes are
/// looked up next to the including file, standard includes under
/// `CARGO_MANIFEST_DIR` (falling back to `/` when that variable is unset).
/// Every path the callback builds is recorded in the returned
/// [`CompileResult::includes`], after the source's own path if it has one.
///
/// # Errors
///
/// Returns the error from `source.source_code()` if the source cannot be
/// obtained, or the shaderc compile error wrapped in an `io::Error` of kind
/// `Other`.
///
/// # Panics
///
/// Panics if the compile options or the shaderc compiler cannot be created.
pub fn compile<S: Source>(source: S, args: args::Args) -> std::io::Result<CompileResult> {
    let source_code = source.source_code()?;

    let mut includes: Vec<String> = vec![];
    if let Some(path) = source.path() {
        includes.push(path);
    }
    // The list is appended to from inside the include callback below, hence
    // the interior mutability.
    let includes = RefCell::new(includes);

    let res = {
        let mut options: shaderc::CompileOptions =
            args.to_options().expect("create shader compile options");
        options.set_include_callback(|name, typ, from, _depth| {
            let mut path = match typ {
                // Same base directory that `FileSource::path` uses for
                // relative source paths.
                shaderc::IncludeType::Standard => std::env::var_os("CARGO_MANIFEST_DIR")
                    .map_or_else(|| std::path::PathBuf::from("/"), std::path::PathBuf::from),
                shaderc::IncludeType::Relative => std::path::PathBuf::from(from)
                    .parent()
                    .map(ToOwned::to_owned)
                    .unwrap_or_else(|| std::path::PathBuf::from("/")),
            };
            path.push(name);
            let resolved_name = path
                .clone()
                .into_os_string()
                .into_string()
                .map_err(|e| format!("path contains invalid utf8: '{:?}'", e))?;
            // Recorded before the read, so the path is tracked even if
            // reading the file fails.
            includes.borrow_mut().push(resolved_name.clone());
            Ok(shaderc::ResolvedInclude {
                resolved_name,
                content: std::fs::read_to_string(path).map_err(|e| e.to_string())?,
            })
        });

        let mut compiler = shaderc::Compiler::new().expect("create compiler");
        let path = source.path();
        compiler
            .compile_into_spirv(
                &source_code,
                args.kind
                    .or_else(|| path.and_then(guess_shader_kind))
                    .unwrap_or(shaderc::ShaderKind::InferFromSource),
                args.name.as_ref().map_or(&source.name(), String::as_str),
                args.entry.as_ref().map_or("main", String::as_str),
                Some(&options),
            )
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?
    };

    let warnings = if 0 != res.get_num_warnings() {
        Some(res.get_warning_messages())
    } else {
        None
    };

    Ok(CompileResult {
        warnings,
        includes: includes.into_inner(),
        data: res.as_binary().to_vec(),
    })
}
/// A provider of shader source text for `compile`.
pub trait Source {
/// Returns the shader source text, or an I/O error if it cannot be obtained.
fn source_code(&self) -> std::io::Result<String>;
/// Returns the path of the source, if it has one. `compile` records it in
/// the include list and uses it to guess the shader kind.
fn path(&self) -> Option<String>;
/// Returns the name `compile` passes to the compiler as the input file name
/// when the arguments do not supply one.
fn name(&self) -> String;
}
/// A shader source stored in a file. Wraps the path string as it was written;
/// relative paths are resolved against `CARGO_MANIFEST_DIR` by `path()`.
pub struct FileSource(String);
impl Parse for FileSource {
    /// Parses a string literal holding the file path, then consumes one
    /// trailing comma if it is present.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let literal: LitStr = input.parse()?;
        // `Option<Token![,]>` peeks first and only consumes the comma when
        // it is actually there.
        let _trailing: Option<Token![,]> = input.parse()?;
        Ok(Self(literal.value()))
    }
}
impl Parse for InlineSource {
    /// Parses a string literal holding the shader text, then consumes one
    /// trailing comma if it is present.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let text = input.parse::<LitStr>()?.value();
        if !input.peek(Token![,]) {
            return Ok(Self(text));
        }
        input.parse::<Token![,]>()?;
        Ok(Self(text))
    }
}
impl Source for FileSource {
    /// Reads the whole file at the resolved path. On failure the error keeps
    /// its kind and its message has the path appended.
    fn source_code(&self) -> std::io::Result<String> {
        let location = self.path().unwrap();
        match std::fs::read_to_string(&location) {
            Ok(text) => Ok(text),
            Err(e) => Err(std::io::Error::new(
                e.kind(),
                format!("{}: {}", e, location),
            )),
        }
    }

    /// Returns the path to read: the wrapped string unchanged when it is not
    /// relative, otherwise that string joined onto `CARGO_MANIFEST_DIR`.
    ///
    /// Always returns `Some`. Panics if the path is relative and
    /// `CARGO_MANIFEST_DIR` is not set.
    fn path(&self) -> Option<String> {
        let given = std::path::Path::new(&self.0);
        if !given.is_relative() {
            return Some(self.0.clone());
        }
        let manifest_dir = std::env::var_os("CARGO_MANIFEST_DIR").unwrap();
        let resolved = std::path::PathBuf::from(manifest_dir).join(given);
        Some(resolved.to_string_lossy().into_owned())
    }

    /// The resolved path doubles as the source's name.
    fn name(&self) -> String {
        self.path().unwrap()
    }
}
/// A shader source written inline as a string. Wraps the shader text itself.
#[derive(Deserialize)]
#[serde(transparent)]
pub struct InlineSource(String);
impl Source for InlineSource {
    /// Returns a copy of the wrapped shader text; never fails.
    fn source_code(&self) -> std::io::Result<String> {
        let Self(text) = self;
        Ok(text.clone())
    }

    /// Inline sources have no path.
    fn path(&self) -> Option<String> {
        Option::None
    }

    /// Fixed placeholder name for inline sources.
    fn name(&self) -> String {
        "<inline>".to_owned()
    }
}
fn guess_shader_kind(path: String) -> Option<shaderc::ShaderKind> {
use shaderc::ShaderKind::*;
match std::path::Path::new(path.as_str()).extension()?.to_str()? {
"vert" => Some(Vertex),
"frag" => Some(Fragment),
"comp" => Some(Compute),
"geom" => Some(Geometry),
"tesc" => Some(TessControl),
"tese" => Some(TessEvaluation),
"spvasm" => Some(SpirvAssembly),
"rgen" => Some(RayGeneration),
"rahit" => Some(AnyHit),
"rchit" => Some(ClosestHit),
"rmiss" => Some(Miss),
"rint" => Some(Intersection),
"rcall" => Some(Callable),
"task" => Some(Task),
"mesh" => Some(Mesh),
_ => None,
}
}