shaderc-macro 0.1.0

Compile shaders with shaderc using macros at compile time
Documentation
use serde::Deserialize;
use std::cell::RefCell;
use syn::parse::{Parse, ParseStream};
use syn::LitStr;
use syn::Token;

mod args;
mod shaderc_serde;

pub fn compile_from<S: Source + Parse>(
    input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    // get the first token tree for source literal
    let mut tokens = input.into_iter();
    let first = proc_macro::TokenStream::from(match tokens.next() {
        Some(f) => f,
        None => {
            return syn::Error::new(proc_macro2::Span::call_site(), "expected string literal")
                .to_compile_error()
                .into()
        }
    });
    let source: S = syn::parse_macro_input!(first as S);

    // get comma
    let input = if let Some(tok) = tokens.next() {
        let is_comma = if let proc_macro::TokenTree::Punct(p) = tok.clone() {
            p.as_char() == ','
        } else {
            false
        };
        let input = tokens.collect::<proc_macro::TokenStream>();
        if is_comma || input.is_empty() {
            input
        } else {
            return syn::Error::new(
                proc_macro2::Span::from(tok.span()),
                format!("unexpected {}", tok.to_string()),
            )
            .to_compile_error()
            .into();
        }
    } else {
        tokens.collect::<proc_macro::TokenStream>()
    };

    let args: args::Args =
        match serde_tokenstream::from_tokenstream(&proc_macro2::TokenStream::from(input)) {
            Ok(args) => args,
            Err(e) => {
                return proc_macro::TokenStream::from(
                    syn::Error::new(proc_macro2::Span::call_site(), e).to_compile_error(),
                )
            }
        };
    match compile(source, args).map_err(|e| syn::Error::new(proc_macro2::Span::call_site(), e)) {
        Ok(CompileResult {
            warnings,
            includes,
            data,
        }) => if let Some(warn) = warnings {
            emit_call_site_warning!(warn);
            quote::quote! {
                    {
                        #(const _INCLUDE: &[u8] = include_bytes!(#includes);)*
                        &[#(#data),*]
                    }
            }
        } else {
            quote::quote! {
                        {
                            #(const _INCLUDE: &[u8] = include_bytes!(#includes);)*
                            &[#(#data),*]
                        }
            }
        }
        .into(),
        Err(err) => proc_macro::TokenStream::from(err.to_compile_error()),
    }
}

/// Output of [`compile`]: the SPIR-V module plus bookkeeping for the macro
/// expansion.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompileResult {
    /// Warning text reported by shaderc, or `None` when it reported no warnings.
    pub warnings: Option<String>,
    /// Every file read during compilation, in the order it was requested:
    /// - the top-level source file, when the source has a path;
    /// - then each `#include` as resolved by the include callback.
    pub includes: Vec<String>,
    /// The compiled SPIR-V binary as 32-bit words.
    pub data: Vec<u32>,
}

/// Compiles `source` to SPIR-V with shaderc, configured by `args`.
///
/// `#include` directives are resolved by a callback installed on the compile
/// options:
/// * `#include "file"` (relative) is looked up in the directory of the
///   including file. If the including file's name has no directory part
///   (e.g. `<inline>`), the `CARGO_MANIFEST_DIR` directory is used instead.
/// * `#include <file>` (standard) is looked up in the `CARGO_MANIFEST_DIR`
///   directory.
///
/// When `CARGO_MANIFEST_DIR` is unset, `/` is used as that base.
///
/// The shader stage comes from the first of these that is available:
/// 1. `args.kind`;
/// 2. a guess from the source path's extension;
/// 3. shaderc's own inference from the source code.
///
/// The input file name is `args.name` or else `source.name()`. The entry
/// point is `args.entry` or else `"main"`.
///
/// # Errors
/// Returns an error if the source text cannot be obtained, or if shaderc
/// reports a compilation error (wrapped as `ErrorKind::Other`).
///
/// # Panics
/// Panics if `args.to_options()` fails or the shaderc compiler cannot be
/// created.
pub fn compile<S: Source>(source: S, args: args::Args) -> std::io::Result<CompileResult> {
    // Base directory for standard includes, and for relative includes whose
    // includer has no directory part. Cargo sets CARGO_MANIFEST_DIR to the
    // manifest directory of the crate being compiled.
    fn include_root() -> std::path::PathBuf {
        std::env::var_os("CARGO_MANIFEST_DIR")
            .map_or_else(|| std::path::PathBuf::from("/"), std::path::PathBuf::from)
    }

    let source_code = source.source_code()?;
    let source_path = source.path();
    // Every file read during compilation, starting with the top-level file.
    let includes: RefCell<Vec<String>> = RefCell::new(source_path.iter().cloned().collect());
    let res = {
        let mut options: shaderc::CompileOptions =
            args.to_options().expect("create shader compile options");

        options.set_include_callback(|name, typ, from, _depth| {
            let mut path = match typ {
                shaderc::IncludeType::Standard => include_root(),
                shaderc::IncludeType::Relative => match std::path::Path::new(from).parent() {
                    Some(dir) if !dir.as_os_str().is_empty() => dir.to_path_buf(),
                    // No directory to be relative to: anchor at the crate so
                    // the path is absolute, both for reading and for the
                    // generated `include_bytes!`.
                    _ => include_root(),
                },
            };
            path.push(name);
            let resolved_name = path
                .to_str()
                .map(str::to_owned)
                .ok_or_else(|| format!("path contains invalid utf8: '{:?}'", path.as_os_str()))?;
            includes.borrow_mut().push(resolved_name.clone());
            let content = std::fs::read_to_string(&path).map_err(|e| e.to_string())?;
            Ok(shaderc::ResolvedInclude {
                resolved_name,
                content,
            })
        });

        let mut compiler = shaderc::Compiler::new().expect("create compiler");

        compiler
            .compile_into_spirv(
                &source_code,
                args.kind
                    .or_else(|| source_path.and_then(guess_shader_kind))
                    .unwrap_or(shaderc::ShaderKind::InferFromSource),
                args.name.as_ref().map_or(&source.name(), String::as_str),
                args.entry.as_ref().map_or("main", String::as_str),
                Some(&options),
            )
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?
        // `options` (and the callback borrowing `includes`) is dropped here,
        // so `includes` can be moved out below.
    };
    let warnings = if res.get_num_warnings() == 0 {
        None
    } else {
        Some(res.get_warning_messages())
    };
    Ok(CompileResult {
        warnings,
        includes: includes.into_inner(),
        data: res.as_binary().to_vec(),
    })
}

/// A shader source that [`compile`] can consume.
pub trait Source {
    /// Returns the shader text to compile.
    ///
    /// # Errors
    /// Any I/O error encountered while obtaining the text.
    fn source_code(&self) -> std::io::Result<String>;

    /// Returns the file path backing this source, or `None` if there is none.
    ///
    /// [`compile`] records this path as a file the expansion depends on. When
    /// no explicit kind is given, it also uses the path's extension to guess
    /// the shader stage.
    fn path(&self) -> Option<String>;

    /// Returns the input file name passed to shaderc, unless the macro's
    /// `name` argument overrides it.
    fn name(&self) -> String;
}

/// Shader source read from a file on disk.
///
/// The wrapped string is the path exactly as written in the macro
/// invocation. See its `Source` impl for how relative paths are resolved.
pub struct FileSource(String);

impl Parse for FileSource {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let src = Self(input.parse::<LitStr>()?.value());
        if input.peek(Token![,]) {
            input.parse::<Token![,]>()?;
        }
        Ok(src)
    }
}

impl Parse for InlineSource {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let src = Self(input.parse::<LitStr>()?.value());
        if input.peek(Token![,]) {
            input.parse::<Token![,]>()?;
        }
        Ok(src)
    }
}

impl FileSource {
    /// Resolves the path given in the macro invocation.
    ///
    /// A relative path is joined onto `CARGO_MANIFEST_DIR`, so it is
    /// independent of the compiler's working directory; the result is
    /// converted lossily to UTF-8. An absolute path is returned unchanged.
    /// So is a relative path when `CARGO_MANIFEST_DIR` is unset (outside
    /// Cargo), instead of panicking.
    fn resolved_path(&self) -> String {
        let path = std::path::Path::new(&self.0);
        if path.is_relative() {
            if let Some(manifest_dir) = std::env::var_os("CARGO_MANIFEST_DIR") {
                return std::path::Path::new(&manifest_dir)
                    .join(path)
                    .to_string_lossy()
                    .into_owned();
            }
        }
        self.0.clone()
    }
}

impl Source for FileSource {
    /// Reads the resolved file. On failure, the error message is suffixed
    /// with the resolved path.
    fn source_code(&self) -> std::io::Result<String> {
        let path = self.resolved_path();
        std::fs::read_to_string(&path)
            .map_err(|e| std::io::Error::new(e.kind(), format!("{}: {}", e, path)))
    }

    /// Always `Some`: the resolved path of the file.
    fn path(&self) -> Option<String> {
        Some(self.resolved_path())
    }

    /// The resolved path doubles as the name reported to shaderc.
    fn name(&self) -> String {
        self.resolved_path()
    }
}

/// Shader source written inline in the macro invocation; the wrapped string
/// is the shader code itself.
///
/// `#[serde(transparent)]` makes it deserialize directly from a plain string.
#[derive(Deserialize)]
#[serde(transparent)]
pub struct InlineSource(String);

impl Source for InlineSource {
    /// The literal already is the shader code, so this never fails.
    fn source_code(&self) -> std::io::Result<String> {
        let code = self.0.to_owned();
        Ok(code)
    }

    /// Inline code has no backing file.
    fn path(&self) -> Option<String> {
        Option::None
    }

    /// Placeholder input file name passed to shaderc, unless the macro's
    /// `name` argument overrides it.
    fn name(&self) -> String {
        "<inline>".to_owned()
    }
}

fn guess_shader_kind(path: String) -> Option<shaderc::ShaderKind> {
    use shaderc::ShaderKind::*;
    match std::path::Path::new(path.as_str()).extension()?.to_str()? {
        "vert" => Some(Vertex),
        "frag" => Some(Fragment),
        "comp" => Some(Compute),
        "geom" => Some(Geometry),
        "tesc" => Some(TessControl),
        "tese" => Some(TessEvaluation),
        "spvasm" => Some(SpirvAssembly),
        "rgen" => Some(RayGeneration),
        "rahit" => Some(AnyHit),
        "rchit" => Some(ClosestHit),
        "rmiss" => Some(Miss),
        "rint" => Some(Intersection),
        "rcall" => Some(Callable),
        "task" => Some(Task),
        "mesh" => Some(Mesh),
        _ => None,
    }
}