1extern crate proc_macro;
154use std::convert::TryInto;
155use std::path::{Path, PathBuf};
156
157mod backends;
158
159use proc_macro::TokenStream;
160use quote::quote;
161use syn::parse::{Parse, ParseStream, Result as ParseResult, Error as ParseError};
162use syn::{parse_macro_input, Ident, LitStr, Token};
163
/// Source language of the shader handed to the macros.
///
/// Selected by the `glsl`, `hlsl`, `wgsl` and `spvasm` options. `Unknown`
/// is the default when no language option is given.
#[derive(Clone, Copy, PartialEq, Eq)]
enum InputSourceLanguage {
    Unknown,
    Glsl,
    Hlsl,
    Wgsl,
    Spvasm,
}
/// SPIR-V version the produced module should target.
///
/// Defaults to 1.0. It is set either by a target-environment preset (for
/// example `vulkan1_1` selects 1.3 and `vulkan1_2` selects 1.5) or by an
/// explicit version option.
#[derive(Clone, Copy, PartialEq, Eq)]
enum TargetSpirvVersion {
    Spirv1_0,
    Spirv1_1,
    Spirv1_2,
    Spirv1_3,
    Spirv1_4,
    Spirv1_5,
    Spirv1_6,
}
/// Client environment the SPIR-V is produced for.
///
/// Defaults to `Vulkan`. It is chosen by the `vulkan*`,
/// `opengl`/`opengl4_5` and `webgpu` options, each of which also overwrites
/// the target SPIR-V version.
#[derive(Clone, Copy, PartialEq, Eq)]
enum TargetEnvironmentType {
    Vulkan,
    OpenGL,
    WebGpu,
}
/// Optimization preference stored in the config handed to the backends.
///
/// Defaults to `None`. The `min_size` / `max_perf` options select `MinSize` /
/// `MaxPerformance`, and the `hlsl` option also switches to `MaxPerformance`.
#[derive(Clone, Copy, PartialEq, Eq)]
enum OptimizationLevel {
    MinSize,
    MaxPerformance,
    None,
}
/// Pipeline stage of the shader being compiled.
///
/// `Unknown` is the default when no stage option is given. Every other
/// variant is selected by a stage option: `vert`, `tesc`, `tese`, `geom`,
/// `frag`, `comp`, `mesh`, `task`, and the ray-tracing stages `rgen`, `rint`,
/// `rahit`, `rchit`, `rmiss`, `rcall`.
#[derive(Clone, Copy, PartialEq, Eq)]
enum ShaderKind {
    Unknown,

    // Graphics / compute stages.
    Vertex,
    TesselationControl,
    TesselationEvaluation,
    Geometry,
    Fragment,
    Compute,
    Mesh,
    Task,
    // Ray-tracing stages.
    RayGeneration,
    Intersection,
    AnyHit,
    ClosestHit,
    Miss,
    Callable,
}
215
/// Options controlling how a shader is compiled.
///
/// Built from the comma-separated arguments that follow the source string or
/// path in a macro invocation.
struct ShaderCompilationConfig {
    /// Source language (`glsl`, `hlsl`, `wgsl`, `spvasm`).
    lang: InputSourceLanguage,
    /// Include directories, one per `I "dir"` option, in the order given.
    incl_dirs: Vec<PathBuf>,
    /// Definitions from `D NAME` (value `None`) or `D NAME = "value"`.
    /// Presumably forwarded to the preprocessor by the backends; confirm.
    defs: Vec<(String, Option<String>)>,
    /// Target SPIR-V version.
    spv_ver: TargetSpirvVersion,
    /// Target client environment.
    env_ty: TargetEnvironmentType,
    /// Entry point name. It is `"main"` unless overridden by `entry = "..."`.
    entry: String,
    /// Requested optimization level.
    optim_lv: OptimizationLevel,
    /// Debug flag, `true` by default and cleared by `no_debug`.
    debug: bool,
    /// Shader stage.
    kind: ShaderKind,
    /// Set by the `auto_bind` option.
    auto_bind: bool,
    /// Y-flip flag, `true` by default and cleared by `no_y_flip`. Only
    /// present with the `naga` feature.
    #[cfg(feature = "naga")]
    y_flip: bool,
}
231impl Default for ShaderCompilationConfig {
232 fn default() -> Self {
233 ShaderCompilationConfig {
234 lang: InputSourceLanguage::Unknown,
235 incl_dirs: Vec::new(),
236 defs: Vec::new(),
237 spv_ver: TargetSpirvVersion::Spirv1_0,
238 env_ty: TargetEnvironmentType::Vulkan,
239 entry: "main".to_owned(),
240 optim_lv: OptimizationLevel::None,
241 debug: true,
242 kind: ShaderKind::Unknown,
243 auto_bind: false,
244
245 #[cfg(feature = "naga")]
246 y_flip: true,
247 }
248 }
249}
250
/// Output of compiling (or loading) a shader.
struct CompilationFeedback {
    /// The SPIR-V module as 32-bit words.
    spv: Vec<u32>,
    /// Files the output depends on. Each one is emitted as an
    /// `include_bytes!` in the macro expansion.
    dep_paths: Vec<String>,
}
/// Parsed and compiled arguments of `inline_spirv!`.
struct InlineShaderSource(CompilationFeedback);
/// Parsed and compiled (or, for `.spv` files, loaded) arguments of
/// `include_spirv!`.
struct IncludedShaderSource(CompilationFeedback);
257
/// Return the directory that relative `include_spirv!` paths resolve against,
/// taken from `CARGO_MANIFEST_DIR`.
///
/// # Panics
///
/// Panics if `CARGO_MANIFEST_DIR` is unset or not valid Unicode (e.g. when
/// the macro is not expanded as part of a Cargo build).
fn get_base_dir() -> PathBuf {
    std::env::var("CARGO_MANIFEST_DIR")
        .map(PathBuf::from)
        .expect("`inline-spirv` can only be used in build time")
}
264#[inline]
265fn parse_str(input: &mut ParseStream) -> ParseResult<String> {
266 input.parse::<LitStr>()
267 .map(|x| x.value())
268}
269#[inline]
270fn parse_ident(input: &mut ParseStream) -> ParseResult<String> {
271 input.parse::<Ident>()
272 .map(|x| x.to_string())
273}
274
275fn parse_compile_cfg(
276 input: &mut ParseStream
277) -> ParseResult<ShaderCompilationConfig> {
278 let mut cfg = ShaderCompilationConfig::default();
279 while !input.is_empty() {
280 use syn::Error;
281 input.parse::<Token![,]>()?;
283 let k = if let Ok(k) = input.parse::<Ident>() { k } else { break };
284 match &k.to_string() as &str {
285 "glsl" => cfg.lang = InputSourceLanguage::Glsl,
286 "hlsl" => {
287 cfg.lang = InputSourceLanguage::Hlsl;
288 cfg.optim_lv = OptimizationLevel::MaxPerformance;
291 },
292 "wgsl" => cfg.lang = InputSourceLanguage::Wgsl,
293 "spvasm" => cfg.lang = InputSourceLanguage::Spvasm,
294
295 "vert" => cfg.kind = ShaderKind::Vertex,
296 "tesc" => cfg.kind = ShaderKind::TesselationControl,
297 "tese" => cfg.kind = ShaderKind::TesselationEvaluation,
298 "geom" => cfg.kind = ShaderKind::Geometry,
299 "frag" => cfg.kind = ShaderKind::Fragment,
300 "comp" => cfg.kind = ShaderKind::Compute,
301 "mesh" => cfg.kind = ShaderKind::Mesh,
302 "task" => cfg.kind = ShaderKind::Task,
303 "rgen" => cfg.kind = ShaderKind::RayGeneration,
304 "rint" => cfg.kind = ShaderKind::Intersection,
305 "rahit" => cfg.kind = ShaderKind::AnyHit,
306 "rchit" => cfg.kind = ShaderKind::ClosestHit,
307 "rmiss" => cfg.kind = ShaderKind::Miss,
308 "rcall" => cfg.kind = ShaderKind::Callable,
309
310 "I" => {
311 cfg.incl_dirs.push(PathBuf::from(parse_str(input)?))
312 },
313 "D" => {
314 let k = parse_ident(input)?;
315 let v = if input.parse::<Token![=]>().is_ok() {
316 Some(parse_str(input)?)
317 } else { None };
318 cfg.defs.push((k, v));
319 },
320
321 "entry" => {
322 if input.parse::<Token![=]>().is_ok() {
323 cfg.entry = parse_str(input)?.to_owned();
324 }
325 }
326
327 "min_size" => cfg.optim_lv = OptimizationLevel::MinSize,
328 "max_perf" => cfg.optim_lv = OptimizationLevel::MaxPerformance,
329
330 "no_debug" => cfg.debug = false,
331
332 "vulkan" | "vulkan1_0" => {
333 cfg.env_ty = TargetEnvironmentType::Vulkan;
334 cfg.spv_ver = TargetSpirvVersion::Spirv1_0;
335 },
336 "vulkan1_1" => {
337 cfg.env_ty = TargetEnvironmentType::Vulkan;
338 cfg.spv_ver = TargetSpirvVersion::Spirv1_3;
339 },
340 "vulkan1_2" => {
341 cfg.env_ty = TargetEnvironmentType::Vulkan;
342 cfg.spv_ver = TargetSpirvVersion::Spirv1_5;
343 },
344 "opengl" | "opengl4_5" => {
345 cfg.env_ty = TargetEnvironmentType::OpenGL;
346 cfg.spv_ver = TargetSpirvVersion::Spirv1_0;
347 },
348 "webgpu" => {
349 cfg.env_ty = TargetEnvironmentType::WebGpu;
350 cfg.spv_ver = TargetSpirvVersion::Spirv1_0;
351 }
352
353 "spirq1_0" => cfg.spv_ver = TargetSpirvVersion::Spirv1_0,
354 "spirq1_1" => cfg.spv_ver = TargetSpirvVersion::Spirv1_1,
355 "spirq1_2" => cfg.spv_ver = TargetSpirvVersion::Spirv1_2,
356 "spirq1_3" => cfg.spv_ver = TargetSpirvVersion::Spirv1_3,
357 "spirq1_4" => cfg.spv_ver = TargetSpirvVersion::Spirv1_4,
358 "spirq1_5" => cfg.spv_ver = TargetSpirvVersion::Spirv1_5,
359 "spirq1_6" => cfg.spv_ver = TargetSpirvVersion::Spirv1_6,
360
361 "auto_bind" => cfg.auto_bind = true,
362
363 #[cfg(feature = "naga")]
364 "no_y_flip" => cfg.y_flip = false,
365
366 _ => return Err(Error::new(k.span(), "unsupported compilation parameter")),
367 }
368 }
369 Ok(cfg)
370}
371
372fn compile(
373 src: &str,
374 path: Option<&str>,
375 cfg: &ShaderCompilationConfig,
376) -> Result<CompilationFeedback, String> {
377 match backends::spirq_spvasm::compile(src, path, cfg) {
378 Ok(x) => return Ok(x),
379 Err(e) if e != "unsupported source language" => return Err(e),
380 _ => {}
381 }
382 #[cfg(feature = "shaderc")]
383 match backends::shaderc::compile(src, path, cfg) {
384 Ok(x) => return Ok(x),
385 Err(e) if e != "unsupported source language" => return Err(e),
386 _ => {}
387 }
388 #[cfg(feature = "naga")]
389 match backends::naga::compile(src, path, cfg) {
390 Ok(x) => return Ok(x),
391 Err(e) if e != "unsupported source language" => return Err(e),
392 _ => {}
393 }
394 Err("no supported backend found".to_owned())
395}
396
/// Load a precompiled SPIR-V binary from `path` and decode it into words.
///
/// The byte order is detected from the SPIR-V magic number (`0x07230203`)
/// in the first word, so both little- and big-endian files are accepted.
///
/// Returns `None` in any of these cases:
///
/// * the file cannot be read;
/// * the file is empty;
/// * its length is not a multiple of four bytes;
/// * it does not start with the magic number in either byte order.
fn build_spirv_binary(path: &Path) -> Option<Vec<u32>> {
    const SPIRV_MAGIC: u32 = 0x0723_0203;

    let bytes = std::fs::read(path).ok()?;
    // SPIR-V is a stream of 32-bit words. A partial trailing word means the
    // file is truncated or not SPIR-V at all. This check must run after the
    // read, not before it.
    if bytes.is_empty() || bytes.len() % 4 != 0 {
        return None;
    }

    let head: [u8; 4] = bytes.get(..4)?.try_into().ok()?;
    let decode: fn([u8; 4]) -> u32 = if u32::from_le_bytes(head) == SPIRV_MAGIC {
        u32::from_le_bytes
    } else if u32::from_be_bytes(head) == SPIRV_MAGIC {
        u32::from_be_bytes
    } else {
        return None;
    };

    let words = bytes
        .chunks_exact(4)
        .map(|chunk| {
            let word: [u8; 4] = chunk
                .try_into()
                .expect("chunks_exact(4) yields 4-byte chunks");
            decode(word)
        })
        .collect();
    Some(words)
}
420
421impl Parse for IncludedShaderSource {
422 fn parse(mut input: ParseStream) -> ParseResult<Self> {
423 use std::ffi::OsStr;
424 let path_lit = input.parse::<LitStr>()?;
425 let path = Path::new(&get_base_dir())
426 .join(&path_lit.value());
427
428 if !path.exists() || !path.is_file() {
429 return Err(ParseError::new(path_lit.span(),
430 format!("{path} is not a valid source file", path=path_lit.value())));
431 }
432
433 let is_spirv = path.is_file() && path.extension() == Some(OsStr::new("spv"));
434 let feedback = if is_spirv {
435 let spv = build_spirv_binary(&path)
436 .ok_or_else(|| syn::Error::new(path_lit.span(), "invalid spirv"))?;
437 CompilationFeedback {
438 spv,
439 dep_paths: vec![],
440 }
441 } else {
442 let src = std::fs::read_to_string(&path)
443 .map_err(|e| syn::Error::new(path_lit.span(), e))?;
444 let cfg = parse_compile_cfg(&mut input)?;
445 compile(&src, Some(path.to_string_lossy().as_ref()), &cfg)
446 .map_err(|e| ParseError::new(input.span(), e))?
447 };
448 let rv = IncludedShaderSource(feedback);
449 Ok(rv)
450 }
451}
452impl Parse for InlineShaderSource {
453 fn parse(mut input: ParseStream) -> ParseResult<Self> {
454 let src = parse_str(&mut input)?;
455 let cfg = parse_compile_cfg(&mut input)?;
456 let feedback = compile(&src, None, &cfg)
457 .map_err(|e| ParseError::new(input.span(), e))?;
458 let rv = InlineShaderSource(feedback);
459 Ok(rv)
460 }
461}
462
463fn gen_token_stream(feedback: CompilationFeedback) -> TokenStream {
464 let CompilationFeedback { spv, dep_paths } = feedback;
465 (quote! {
466 {
467 { #(let _ = include_bytes!(#dep_paths);)* }
468 &[#(#spv),*]
469 }
470 }).into()
471}
472
473#[proc_macro]
476pub fn inline_spirv(tokens: TokenStream) -> TokenStream {
477 let InlineShaderSource(feedback) = parse_macro_input!(tokens as InlineShaderSource);
478 gen_token_stream(feedback)
479}
480#[proc_macro]
483pub fn include_spirv(tokens: TokenStream) -> TokenStream {
484 let IncludedShaderSource(feedback) = parse_macro_input!(tokens as IncludedShaderSource);
485 gen_token_stream(feedback)
486}