rafx_shader_processor/
lib.rs

1use std::error::Error;
2use std::path::{Path, PathBuf};
3use structopt::StructOpt;
4
5mod parse_source;
6use parse_source::AnnotationText;
7use parse_source::DeclarationText;
8
9mod parse_declarations;
10
11mod include;
12use crate::parse_declarations::ParseDeclarationsResult;
13use crate::parse_source::PreprocessorState;
14use crate::reflect::ShaderProcessorRefectionData;
15use crate::shader_types::{TypeAlignmentInfo, UserType};
16use fnv::{FnvHashMap, FnvHashSet};
17use include::include_impl;
18use include::IncludeType;
19use rafx_api::{
20    RafxHashedShaderPackage, RafxShaderPackage, RafxShaderPackageDx12, RafxShaderPackageGles2,
21    RafxShaderPackageGles3, RafxShaderPackageMetal, RafxShaderPackageVulkan,
22};
23use shaderc::{CompilationArtifact, Compiler, ShaderKind};
24use spirv_cross::glsl::Target;
25use spirv_cross::spirv::{Ast, ShaderResources};
26
27mod codegen;
28
29mod reflect;
30
31mod shader_types;
32
// Platform defines. `compile_glsl` defines the selected one to `1` both for shaderc and for the
// processor's own GLSL declaration parser, so shader sources can branch on the target with
// `#ifdef PLATFORM_*`.

/// Defined while compiling the shader for Rust code generation.
const PREPROCESSOR_DEF_PLATFORM_RUST_CODEGEN: &str = "PLATFORM_RUST_CODEGEN";
/// Defined while compiling the shader for DX12 (HLSL) output.
const PREPROCESSOR_DEF_PLATFORM_DX12: &str = "PLATFORM_DX12";
/// Defined while compiling the shader for Vulkan (SPIR-V) output.
const PREPROCESSOR_DEF_PLATFORM_VULKAN: &str = "PLATFORM_VULKAN";
/// Defined while compiling the shader for Metal (MSL) output.
const PREPROCESSOR_DEF_PLATFORM_METAL: &str = "PLATFORM_METAL";
/// Platform define for GLES2 output; presumably passed to `compile_glsl` by the GLES2
/// cross-compile path (confirm).
const PREPROCESSOR_DEF_PLATFORM_GLES2: &str = "PLATFORM_GLES2";
/// Platform define for GLES3 output; presumably passed to `compile_glsl` by the GLES3
/// cross-compile path (confirm).
const PREPROCESSOR_DEF_PLATFORM_GLES3: &str = "PLATFORM_GLES3";
39
/// Flavor of module root file emitted next to generated per-shader Rust files.
///
/// `Lib` is chosen by `--rs-lib-path`, `Mod` by `--rs-mod-path`.
#[derive(Debug, Copy, Clone)]
enum RsFileType {
    /// The top-level module file is named `lib.rs`.
    Lib,
    /// The top-level module file is named `mod.rs`.
    Mod,
}
45
46#[derive(Debug)]
47struct RsFileOption {
48    path: PathBuf,
49    file_type: RsFileType,
50}
51
// Command-line arguments for the shader processor.
//
// Plain `//` comments are used inside this struct on purpose: structopt turns `///` doc comments
// into `--help` text, so adding them would change the CLI's output.
#[derive(StructOpt, Debug)]
pub struct ShaderProcessorArgs {
    //
    // For one file at a time
    //
    // Input shader. When set, `run` processes only this file and ignores --glsl-path.
    #[structopt(name = "glsl-file", long, parse(from_os_str))]
    pub glsl_file: Option<PathBuf>,
    // Explicit output files for the single shader given by --glsl-file.
    #[structopt(name = "spv-file", long, parse(from_os_str))]
    pub spv_file: Option<PathBuf>,
    #[structopt(name = "rs-file", long, parse(from_os_str))]
    pub rs_file: Option<PathBuf>,
    #[structopt(name = "dx12-generated-src-file", long, parse(from_os_str))]
    pub dx12_generated_src_file: Option<PathBuf>,
    #[structopt(name = "metal-generated-src-file", long, parse(from_os_str))]
    pub metal_generated_src_file: Option<PathBuf>,
    #[structopt(name = "gles2-generated-src-file", long, parse(from_os_str))]
    pub gles2_generated_src_file: Option<PathBuf>,
    #[structopt(name = "gles3-generated-src-file", long, parse(from_os_str))]
    pub gles3_generated_src_file: Option<PathBuf>,
    #[structopt(name = "cooked-shader-file", long, parse(from_os_str))]
    pub cooked_shader_file: Option<PathBuf>,

    //
    // For batch processing a folder
    //
    // Root directory searched for `*.vert`, `*.frag` and `*.comp` shaders (see
    // `process_directory`). Note that the flag is `glsl-path` while the field is `glsl_files`.
    #[structopt(name = "glsl-path", long, parse(from_os_str))]
    pub glsl_files: Option<PathBuf>,
    // Output directories. Each shader's outputs are placed under them, mirroring the shader's
    // location relative to --glsl-path.
    #[structopt(name = "spv-path", long, parse(from_os_str))]
    pub spv_path: Option<PathBuf>,
    // Output directory for generated Rust code together with generated lib.rs/mod.rs module
    // files. --rs-lib-path and --rs-mod-path are mutually exclusive (`run` errors if both are set).
    #[structopt(name = "rs-lib-path", long, parse(from_os_str))]
    pub rs_lib_path: Option<PathBuf>,
    #[structopt(name = "rs-mod-path", long, parse(from_os_str))]
    pub rs_mod_path: Option<PathBuf>,
    #[structopt(name = "dx12-generated-src-path", long, parse(from_os_str))]
    pub dx12_generated_src_path: Option<PathBuf>,
    #[structopt(name = "metal-generated-src-path", long, parse(from_os_str))]
    pub metal_generated_src_path: Option<PathBuf>,
    #[structopt(name = "gles2-generated-src-path", long, parse(from_os_str))]
    pub gles2_generated_src_path: Option<PathBuf>,
    #[structopt(name = "gles3-generated-src-path", long, parse(from_os_str))]
    pub gles3_generated_src_path: Option<PathBuf>,
    #[structopt(name = "cooked-shaders-path", long, parse(from_os_str))]
    pub cooked_shaders_path: Option<PathBuf>,

    // Explicit shader kind. When absent, the kind is deduced from the file path, falling back to
    // shaderc's `InferFromSource` (see `shader_kind_from_args`; accepted values are defined there).
    #[structopt(name = "shader-kind", long)]
    pub shader_kind: Option<String>,

    // NOTE(review): not read by the processing functions here; presumably consumed by the
    // binary's logging setup — confirm.
    #[structopt(name = "trace", long)]
    pub trace: bool,

    // Run an extra performance-optimized shaderc compile for the Vulkan SPIR-V output.
    // Reflection data still comes from the unoptimized compile.
    #[structopt(name = "optimize-shaders", long)]
    pub optimize_shaders: bool,

    // Targets to bundle into the cooked shader package. These only take effect when a cooked
    // shader file/path is given. Requesting a cooked output with none of them set is an error.
    #[structopt(name = "package-vk", long)]
    pub package_vk: bool,
    #[structopt(name = "package-dx12", long)]
    pub package_dx12: bool,
    #[structopt(name = "package-metal", long)]
    pub package_metal: bool,
    #[structopt(name = "package-gles2", long)]
    pub package_gles2: bool,
    #[structopt(name = "package-gles3", long)]
    pub package_gles3: bool,
    #[structopt(name = "package-all", long)]
    pub package_all: bool,

    // Forwarded to `codegen::generate_rust_code`. Presumably it adjusts generated crate paths
    // for code living inside rafx-framework itself (confirm).
    #[structopt(name = "for-rafx-framework-crate", long)]
    pub for_rafx_framework_crate: bool,
}
121
122pub fn run(args: &ShaderProcessorArgs) -> Result<(), Box<dyn Error>> {
123    log::trace!("Shader processor args: {:#?}", args);
124    if args.rs_lib_path.is_some() && args.rs_mod_path.is_some() {
125        Err("Both --rs-lib-path and --rs-mod-path were provided, using both at the same time is not supported.")?;
126    }
127
128    let rs_file_option = if let Some(path) = &args.rs_lib_path {
129        Some(RsFileOption {
130            path: path.clone(),
131            file_type: RsFileType::Lib,
132        })
133    } else if let Some(path) = &args.rs_mod_path {
134        Some(RsFileOption {
135            path: path.clone(),
136            file_type: RsFileType::Mod,
137        })
138    } else {
139        None
140    };
141
142    if let Some(glsl_file) = &args.glsl_file {
143        //
144        // Handle a single file given via --glsl_file. In this mode, the output files are explicit
145        //
146        log::info!("Processing file {:?}", glsl_file);
147
148        //
149        // Try to determine what kind of shader this is from the file name
150        //
151        let shader_kind = shader_kind_from_args(args)
152            .or_else(|| deduce_default_shader_kind_from_path(glsl_file))
153            .unwrap_or(shaderc::ShaderKind::InferFromSource);
154
155        //
156        // Process this shader and write to output files
157        //
158        process_glsl_shader(
159            glsl_file,
160            args.spv_file.as_ref(),
161            &rs_file_option,
162            args.dx12_generated_src_path.as_ref(),
163            args.metal_generated_src_file.as_ref(),
164            args.gles2_generated_src_file.as_ref(),
165            args.gles3_generated_src_file.as_ref(),
166            args.cooked_shader_file.as_ref(),
167            shader_kind,
168            &args,
169        )
170        .map_err(|x| format!("{}: {}", glsl_file.to_string_lossy(), x.to_string()))?;
171
172        Ok(())
173    } else if let Some(glsl_files) = &args.glsl_files {
174        log::trace!("glsl files {:?}", args.glsl_files);
175        process_directory(glsl_files, &args, &rs_file_option)
176    } else {
177        Ok(())
178    }
179}
180
181//
182// Handle a batch of file patterns (such as *.frag) via --glsl_files. Infer output files
183// based on other args given in the form of output directories
184//
185fn process_directory(
186    glsl_files: &PathBuf,
187    args: &ShaderProcessorArgs,
188    rs_file_option: &Option<RsFileOption>,
189) -> Result<(), Box<dyn Error>> {
190    // This will accumulate rust module names so we can produce a lib.rs if needed
191    let mut module_names = FnvHashMap::<PathBuf, FnvHashSet<String>>::default();
192
193    log::trace!("GLSL Root Dir: {:?}", glsl_files);
194
195    let glob_walker = globwalk::GlobWalkerBuilder::from_patterns(
196        glsl_files.to_str().unwrap(),
197        &["*.{vert,frag,comp}"],
198    )
199    .file_type(globwalk::FileType::FILE)
200    .build()?;
201
202    for glob in glob_walker {
203        //
204        // Determine the files we will write out
205        //
206        let glsl_file = glob?;
207        log::info!("Processing file {:?}", glsl_file.path());
208
209        let file_name = glsl_file.file_name().to_string_lossy();
210
211        let empty_path = PathBuf::new();
212        let outfile_prefix = glsl_file
213            .path()
214            .strip_prefix(glsl_files)?
215            .parent()
216            .unwrap_or(&empty_path);
217
218        let rs_module_name = file_name.to_string().to_lowercase().replace(".", "_");
219        let rs_name = format!("{}.rs", rs_module_name);
220        let rs_file_option = rs_file_option.as_ref().map(|x| RsFileOption {
221            path: x.path.join(outfile_prefix).join(rs_name),
222            file_type: x.file_type,
223        });
224
225        let spv_name = format!("{}.spv", file_name);
226        let spv_path = args
227            .spv_path
228            .as_ref()
229            .map(|x| x.join(outfile_prefix).join(spv_name));
230
231        let dx12_src_name = format!("{}.hlsl", file_name);
232        let dx12_generated_src_path = args
233            .dx12_generated_src_path
234            .as_ref()
235            .map(|x| x.join(outfile_prefix).join(dx12_src_name));
236
237        let metal_src_name = format!("{}.metal", file_name);
238        let metal_generated_src_path = args
239            .metal_generated_src_path
240            .as_ref()
241            .map(|x| x.join(outfile_prefix).join(metal_src_name));
242
243        let gles2_src_name = format!("{}.gles2", file_name);
244        let gles2_generated_src_path = args
245            .gles2_generated_src_path
246            .as_ref()
247            .map(|x| x.join(outfile_prefix).join(gles2_src_name));
248
249        let gles3_src_name = format!("{}.gles3", file_name);
250        let gles3_generated_src_path = args
251            .gles3_generated_src_path
252            .as_ref()
253            .map(|x| x.join(outfile_prefix).join(gles3_src_name));
254
255        let cooked_shader_name = format!("{}.cookedshaderpackage", file_name);
256        let cooked_shader_path = args
257            .cooked_shaders_path
258            .as_ref()
259            .map(|x| x.join(outfile_prefix).join(cooked_shader_name));
260
261        //
262        // Try to determine what kind of shader this is from the file name
263        //
264        let shader_kind = shader_kind_from_args(args)
265            .or_else(|| deduce_default_shader_kind_from_path(glsl_file.path()))
266            .unwrap_or(shaderc::ShaderKind::InferFromSource);
267
268        //
269        // Process this shader and write to output files
270        //
271        process_glsl_shader(
272            glsl_file.path(),
273            spv_path.as_ref(),
274            &rs_file_option,
275            dx12_generated_src_path.as_ref(),
276            metal_generated_src_path.as_ref(),
277            gles2_generated_src_path.as_ref(),
278            gles3_generated_src_path.as_ref(),
279            cooked_shader_path.as_ref(),
280            shader_kind,
281            &args,
282        )
283        .map_err(|x| format!("{}: {}", glsl_file.path().to_string_lossy(), x.to_string()))?;
284
285        //
286        // Add the module name to this list so we can generate a lib.rs later
287        //
288        if rs_file_option.is_some() {
289            let module_names = module_names
290                .entry(outfile_prefix.to_path_buf())
291                .or_default();
292            module_names.insert(rs_module_name.clone());
293        }
294    }
295    //
296    // Generate lib.rs or mod.rs files that includes all the compiled shaders
297    //
298    if let Some(rs_path) = &rs_file_option {
299        // First ensure that for any nested submodules, they are declared in lib.rs/mod.rs files in
300        // the parent dirs
301        let outfile_prefixes: Vec<_> = module_names.keys().cloned().collect();
302        for mut outfile_prefix in outfile_prefixes {
303            while let Some(parent) = outfile_prefix.parent() {
304                let new_module_name = outfile_prefix
305                    .file_name()
306                    .unwrap()
307                    .to_string_lossy()
308                    .to_string();
309
310                log::trace!("add module {:?} to {:?}", new_module_name, parent);
311
312                let module_names = module_names.entry(parent.to_path_buf()).or_default();
313                module_names.insert(new_module_name);
314
315                outfile_prefix = parent.to_path_buf();
316            }
317        }
318
319        // Generate all lib.rs/mod.rs files
320        for (outfile_prefix, module_names) in module_names {
321            let module_filename = match rs_path.file_type {
322                RsFileType::Lib => "lib.rs",
323                RsFileType::Mod => "mod.rs",
324            };
325            let lib_file_path = rs_path.path.join(outfile_prefix).join(module_filename);
326            log::trace!("Write lib/mod file {:?} {:?}", lib_file_path, module_names);
327
328            let mut lib_file_string = String::default();
329            lib_file_string += "// This code is auto-generated by the shader processor.\n\n";
330            lib_file_string += "#![allow(dead_code)]\n\n";
331
332            for module_name in module_names {
333                lib_file_string += &format!("pub mod {};\n", module_name);
334            }
335
336            write_output_file(&lib_file_path, lib_file_string)?;
337        }
338    }
339
340    Ok(())
341}
342
343fn process_glsl_shader(
344    glsl_file: &Path,
345    spv_file: Option<&PathBuf>,
346    rs_file: &Option<RsFileOption>,
347    dx12_generated_src_file: Option<&PathBuf>,
348    metal_generated_src_file: Option<&PathBuf>,
349    gles2_generated_src_file: Option<&PathBuf>,
350    gles3_generated_src_file: Option<&PathBuf>,
351    cooked_shader_file: Option<&PathBuf>,
352    shader_kind: shaderc::ShaderKind,
353    args: &ShaderProcessorArgs,
354) -> Result<(), Box<dyn Error>> {
355    log::trace!("--- Start processing shader job ---");
356    log::trace!("glsl: {:?}", glsl_file);
357    log::trace!("spv: {:?}", spv_file);
358    log::trace!("rs: {:?}", rs_file);
359    log::trace!("dx12: {:?}", dx12_generated_src_file);
360    log::trace!("metal: {:?}", metal_generated_src_file);
361    log::trace!("gles2: {:?}", gles2_generated_src_file);
362    log::trace!("gles3: {:?}", gles3_generated_src_file);
363    log::trace!("cooked: {:?}", cooked_shader_file);
364    log::trace!("shader kind: {:?}", shader_kind);
365
366    let package_vk = (args.package_all || args.package_vk) && cooked_shader_file.is_some();
367    let package_dx12 = (args.package_all || args.package_dx12) && cooked_shader_file.is_some();
368    let package_metal = (args.package_all || args.package_metal) && cooked_shader_file.is_some();
369    let package_gles2 = (args.package_all || args.package_gles2) && cooked_shader_file.is_some();
370    let package_gles3 = (args.package_all || args.package_gles3) && cooked_shader_file.is_some();
371
372    log::trace!(
373        "package VK: {} dx12: {} Metal: {} GLES2: {} GLES3: {}",
374        package_vk,
375        package_dx12,
376        package_metal,
377        package_gles2,
378        package_gles3
379    );
380
381    if cooked_shader_file.is_some()
382        && !(package_vk || package_dx12 || package_metal || package_gles2 || package_gles3)
383    {
384        Err("A cooked shader file or path was specified but no shader types are specified to package. Pass --package-vk, --package-dx12, --package-metal, --package-gles2, --package-gles3, or --package-all")?;
385    }
386
387    let code = std::fs::read_to_string(&glsl_file)?;
388    let entry_point_name = "main";
389
390    //
391    // First, compile the code with shaderc. This will validate that it's well-formed. We will also
392    // use the produced spv to create reflection data. This first pass must be UNOPTIMIZED so that
393    // we don't drop reflection data for unused elements.
394    //
395    // We want to preserve unused fields so that the rust API we generate does not substantially
396    // change and cause spurious compile errors just because a line of code gets commented out in
397    // the shader. (In the future we may want to generate the API but make it a noop.)
398    //
399    let generate_reflection_data = rs_file.is_some()
400        || cooked_shader_file.is_some()
401        || dx12_generated_src_file.is_some()
402        || metal_generated_src_file.is_some()
403        || gles2_generated_src_file.is_some();
404
405    let require_semantics = cooked_shader_file.is_some() || dx12_generated_src_file.is_some();
406    let compiler = shaderc::Compiler::new().unwrap();
407
408    let compile_parameters = CompileParameters {
409        glsl_file,
410        shader_kind,
411        code: &code,
412        entry_point_name,
413        generate_reflection_data,
414        require_semantics,
415        compiler: &compiler,
416    };
417
418    let rust_code = if rs_file.is_some() {
419        let mut compile_result =
420            compile_glsl(&compile_parameters, PREPROCESSOR_DEF_PLATFORM_RUST_CODEGEN)?;
421
422        log::trace!("{:?}: generate rust code", glsl_file);
423        let reflected_entry_point = compile_result
424            .reflection_data
425            .as_ref()
426            .unwrap()
427            .reflection
428            .iter()
429            .find(|x| x.rafx_api_reflection.entry_point_name == entry_point_name)
430            .ok_or_else(|| {
431                format!(
432                    "Could not find entry point {} in compiled shader file",
433                    entry_point_name
434                )
435            })?;
436
437        //
438        // Generate rust code that matches up with the shader
439        //
440        log::trace!("{:?}: generate rust code", glsl_file);
441        Some(codegen::generate_rust_code(
442            &compile_result.builtin_types,
443            &mut compile_result.user_types,
444            &compile_result.parsed_declarations,
445            //&spirv_reflect_module,
446            &reflected_entry_point,
447            args.for_rafx_framework_crate,
448        )?)
449    } else {
450        None
451    };
452
453    let vk_output = if spv_file.is_some() || package_vk {
454        Some(cross_compile_to_vulkan(
455            glsl_file,
456            &compile_parameters,
457            &args,
458        )?)
459    } else {
460        None
461    };
462
463    let dx12_output = if dx12_generated_src_file.is_some() || package_dx12 {
464        Some(cross_compile_to_dx12(glsl_file, &compile_parameters)?)
465    } else {
466        None
467    };
468
469    let metal_output = if metal_generated_src_file.is_some() || package_metal {
470        Some(cross_compile_to_metal(glsl_file, &compile_parameters)?)
471    } else {
472        None
473    };
474
475    let gles2_output = if gles2_generated_src_file.is_some() || package_gles2 {
476        Some(cross_compile_to_gles2(glsl_file, &compile_parameters)?)
477    } else {
478        None
479    };
480
481    let gles3_output = if gles3_generated_src_file.is_some() || package_gles3 {
482        Some(cross_compile_to_gles3(glsl_file, &compile_parameters)?)
483    } else {
484        None
485    };
486
487    //
488    // Write out the spv and rust files if desired
489    //
490    if let Some(spv_file) = &spv_file {
491        write_output_file(spv_file, &vk_output.as_ref().unwrap().vk_spv)?;
492    }
493
494    if let Some(rs_file) = &rs_file {
495        write_output_file(&rs_file.path, rust_code.unwrap())?;
496    }
497
498    if let Some(dx12_generated_src_file) = &dx12_generated_src_file {
499        write_output_file(
500            dx12_generated_src_file,
501            &dx12_output.as_ref().unwrap().dx12_src,
502        )?;
503    }
504
505    if let Some(metal_generated_src_file) = &metal_generated_src_file {
506        write_output_file(
507            metal_generated_src_file,
508            &metal_output.as_ref().unwrap().metal_src,
509        )?;
510    }
511
512    if let Some(gles2_generated_src_file) = &gles2_generated_src_file {
513        write_output_file(
514            gles2_generated_src_file,
515            &gles2_output.as_ref().unwrap().gles2_src,
516        )?;
517    }
518
519    if let Some(gles3_generated_src_file) = &gles3_generated_src_file {
520        write_output_file(
521            gles3_generated_src_file,
522            &gles3_output.as_ref().unwrap().gles3_src,
523        )?;
524    }
525
526    // Don't worry about the return value
527    log::trace!("{:?}: cook shader", glsl_file);
528    let cooked_shader = if cooked_shader_file.is_some() {
529        let mut shader_package = RafxShaderPackage::default();
530
531        if package_vk {
532            let vk_output = vk_output.unwrap();
533            shader_package.vk = Some(RafxShaderPackageVulkan::SpvBytes(vk_output.vk_spv));
534            shader_package.vk_reflection = vk_output.reflection_data.map(|x| x.reflection);
535        };
536
537        if package_dx12 {
538            let dx12_output = dx12_output.unwrap();
539            shader_package.dx12 = Some(RafxShaderPackageDx12::Src(dx12_output.dx12_src));
540            shader_package.dx12_reflection = dx12_output.reflection_data.map(|x| x.reflection);
541        };
542
543        if package_metal {
544            let metal_output = metal_output.unwrap();
545            shader_package.metal = Some(RafxShaderPackageMetal::Src(metal_output.metal_src));
546            shader_package.metal_reflection = metal_output.reflection_data.map(|x| x.reflection);
547        };
548
549        if package_gles2 {
550            let gles2_output = gles2_output.unwrap();
551            shader_package.gles2 = Some(RafxShaderPackageGles2::Src(gles2_output.gles2_src));
552            shader_package.gles2_reflection = gles2_output.reflection_data.map(|x| x.reflection);
553        };
554
555        if package_gles3 {
556            let gles3_output = gles3_output.unwrap();
557            shader_package.gles3 = Some(RafxShaderPackageGles3::Src(gles3_output.gles3_src));
558            shader_package.gles3_reflection = gles3_output.reflection_data.map(|x| x.reflection);
559        };
560
561        shader_package.debug_name =
562            Some(glsl_file.file_name().unwrap().to_string_lossy().to_string());
563        let hashed_shader_package = RafxHashedShaderPackage::new(shader_package);
564
565        let serialized = bincode::serialize(&hashed_shader_package)
566            .map_err(|x| format!("Failed to serialize cooked shader: {}", x))?;
567        Some(serialized)
568    } else {
569        None
570    };
571
572    if let Some(cooked_shader_file) = &cooked_shader_file {
573        write_output_file(cooked_shader_file, cooked_shader.unwrap())?;
574    }
575
576    Ok(())
577}
578
579struct CompileParameters<'a> {
580    glsl_file: &'a Path,
581    shader_kind: ShaderKind,
582    code: &'a str,
583    entry_point_name: &'a str,
584    generate_reflection_data: bool,
585    require_semantics: bool,
586    compiler: &'a Compiler,
587}
588
/// Everything `compile_glsl` produces for one platform define.
struct CompileResult {
    /// Unoptimized SPIR-V from shaderc, compiled with debug info. The HLSL/MSL cross-compilers
    /// re-parse this with spirv_cross.
    unoptimized_spv: CompilationArtifact,
    /// Declarations extracted from the GLSL source by the processor's own parser.
    parsed_declarations: ParseDeclarationsResult,
    /// spirv_cross GLSL AST of `unoptimized_spv` (Vulkan semantics enabled), used for reflection.
    ast: Ast<Target>,
    /// User-defined types built from `parsed_declarations`; presumably keyed by GLSL type
    /// name (confirm).
    user_types: FnvHashMap<String, UserType>,
    /// Layout information for built-in GLSL types; presumably keyed by type name (confirm).
    builtin_types: FnvHashMap<String, TypeAlignmentInfo>,
    /// Reflection data. `Some` only when `CompileParameters::generate_reflection_data` was set.
    reflection_data: Option<ShaderProcessorRefectionData>,
}
597
598fn try_load_override_src(
599    original_path: &Path,
600    extension: &str,
601) -> Result<Option<String>, Box<dyn Error>> {
602    let mut override_path = original_path.as_os_str().to_os_string();
603    override_path.push(extension);
604    let override_path = PathBuf::from(override_path);
605    if override_path.exists() {
606        log::info!(
607            "  Override shader {:?} with {:?}",
608            original_path,
609            override_path.to_string_lossy()
610        );
611
612        let override_src = std::fs::read_to_string(&override_path)?;
613
614        // We want to inline all the #includes because we are packaging the source for compilation
615        // on target hardware and it won't be able to #include dependencies.
616        let preprocessed_src =
617            parse_source::inline_includes_in_override_src(&override_path, &override_src)?;
618
619        Ok(Some(preprocessed_src))
620    } else {
621        Ok(None)
622    }
623}
624
625fn compile_glsl(
626    parameters: &CompileParameters,
627    platform_define: &str,
628) -> Result<CompileResult, Box<dyn Error>> {
629    log::trace!("{:?}: compile unoptimized", parameters.glsl_file);
630    let (unoptimized_spv, parsed_source) = {
631        let mut compile_options = shaderc::CompileOptions::new().unwrap();
632        compile_options.set_include_callback(include::shaderc_include_callback);
633        compile_options.set_generate_debug_info();
634        compile_options.add_macro_definition(platform_define, Some("1"));
635
636        log::trace!("compile to spriv for platform {:?}", platform_define);
637
638        let unoptimized_spv = parameters.compiler.compile_into_spirv(
639            &parameters.code,
640            parameters.shader_kind,
641            parameters.glsl_file.to_str().unwrap(),
642            parameters.entry_point_name,
643            Some(&compile_options),
644        )?;
645
646        log::trace!("{:?}: parse glsl", parameters.glsl_file);
647
648        let mut preprocessor_state = PreprocessorState::default();
649        preprocessor_state.add_define(platform_define.to_string(), "1".to_string());
650        let parsed_source = parse_source::parse_glsl_src(
651            &parameters.glsl_file,
652            &parameters.code,
653            &mut preprocessor_state,
654        )?;
655
656        (unoptimized_spv, parsed_source)
657    };
658
659    //
660    // Read the unoptimized spv into spirv_cross so that we can grab reflection data
661    //
662    log::trace!("{:?}: read spirv_cross module", parameters.glsl_file);
663    let spirv_cross_module = spirv_cross::spirv::Module::from_words(unoptimized_spv.as_binary());
664
665    //
666    // Parse the declarations that were extracted from the source file
667    //
668    log::trace!("{:?}: parse declarations", parameters.glsl_file);
669    let parsed_declarations = parse_declarations::parse_declarations(&parsed_source.declarations)?;
670    let is_compute_shader = normalize_shader_kind(parameters.shader_kind) == ShaderKind::Compute;
671    if parsed_declarations.group_size.is_some() && !is_compute_shader {
672        Err("The shader is not a compute shader but a group size was specified")?;
673    } else if parsed_declarations.group_size.is_none() && is_compute_shader {
674        Err("The shader is a compute shader but a group size was not specified. Expected to find something like `layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;` in the shader")?;
675    }
676
677    log::trace!("{:?}: generate spirv_cross ast", parameters.glsl_file);
678    let mut spirv_cross_glsl_options = spirv_cross::glsl::CompilerOptions::default();
679    spirv_cross_glsl_options.vulkan_semantics = true;
680    let mut ast = spirv_cross::spirv::Ast::<spirv_cross::glsl::Target>::parse(&spirv_cross_module)?;
681    ast.set_compiler_options(&spirv_cross_glsl_options)?;
682
683    log::trace!("{:?}: generate shader types", parameters.glsl_file);
684    let user_types = shader_types::create_user_type_lookup(&parsed_declarations)?;
685    let builtin_types = shader_types::create_builtin_type_lookup();
686
687    let reflected_data = if parameters.generate_reflection_data {
688        log::trace!("{:?}: generate reflection data", parameters.glsl_file);
689        Some(reflect::reflect_data(
690            &builtin_types,
691            &user_types,
692            &ast,
693            &parsed_declarations,
694            parameters.require_semantics,
695        )?)
696    } else {
697        None
698    };
699
700    Ok(CompileResult {
701        unoptimized_spv,
702        parsed_declarations,
703        ast,
704        user_types,
705        builtin_types,
706        reflection_data: reflected_data,
707    })
708}
709
710pub struct CrossCompileOutputVulkan {
711    vk_spv: Vec<u8>,
712    reflection_data: Option<ShaderProcessorRefectionData>,
713}
714
715fn cross_compile_to_vulkan(
716    glsl_file: &Path,
717    compile_parameters: &CompileParameters,
718    args: &ShaderProcessorArgs,
719) -> Result<CrossCompileOutputVulkan, Box<dyn Error>> {
720    log::trace!("{:?}: create vulkan", glsl_file);
721    let compile_result = compile_glsl(compile_parameters, PREPROCESSOR_DEF_PLATFORM_VULKAN)?;
722
723    let vk_spv = if args.optimize_shaders {
724        let mut compile_options = shaderc::CompileOptions::new().unwrap();
725        compile_options.set_include_callback(include::shaderc_include_callback);
726        compile_options.set_optimization_level(shaderc::OptimizationLevel::Performance);
727        //NOTE: Could also use shaderc::OptimizationLevel::Size
728
729        compile_parameters
730            .compiler
731            .compile_into_spirv(
732                compile_parameters.code,
733                compile_parameters.shader_kind,
734                glsl_file.to_str().unwrap(),
735                compile_parameters.entry_point_name,
736                Some(&compile_options),
737            )?
738            .as_binary_u8()
739            .to_vec()
740    } else {
741        compile_result.unoptimized_spv.as_binary_u8().to_vec()
742    };
743
744    Ok(CrossCompileOutputVulkan {
745        vk_spv,
746        reflection_data: compile_result.reflection_data,
747    })
748}
749
750pub struct CrossCompileOutputDx12 {
751    dx12_src: String,
752    reflection_data: Option<ShaderProcessorRefectionData>,
753}
754
755fn cross_compile_to_dx12(
756    glsl_file: &Path,
757    compile_parameters: &CompileParameters,
758) -> Result<CrossCompileOutputDx12, Box<dyn Error>> {
759    log::trace!("{:?}: create dx12", glsl_file);
760    let compile_result = compile_glsl(compile_parameters, PREPROCESSOR_DEF_PLATFORM_DX12)?;
761
762    let dx12_src = if let Some(src) = try_load_override_src(glsl_file, ".hlsl")? {
763        src
764    } else {
765        let spirv_cross_module =
766            spirv_cross::spirv::Module::from_words(compile_result.unoptimized_spv.as_binary());
767
768        let mut hlsl_ast =
769            spirv_cross::spirv::Ast::<spirv_cross::hlsl::Target>::parse(&spirv_cross_module)?;
770        let mut spirv_cross_hlsl_options = spirv_cross::hlsl::CompilerOptions::default();
771        spirv_cross_hlsl_options.shader_model = spirv_cross::hlsl::ShaderModel::V6_0;
772        spirv_cross_hlsl_options.flatten_matrix_vertex_input_semantics = true;
773        //DX12TODO: We want something more fine-grained than this
774        spirv_cross_hlsl_options.force_storage_buffer_as_uav = true;
775
776        for assignment in &compile_result
777            .reflection_data
778            .as_ref()
779            .unwrap()
780            .hlsl_register_assignments
781        {
782            hlsl_ast.add_resource_binding(assignment)?;
783        }
784
785        for remap in &compile_result
786            .reflection_data
787            .as_ref()
788            .unwrap()
789            .hlsl_vertex_attribute_remaps
790        {
791            // We require semantics to produce HLSL, an error should be thrown earlier if they are missing
792            assert!(!remap.semantic.is_empty());
793            if !remap.semantic.is_empty() {
794                hlsl_ast.add_vertex_attribute_remap(remap)?;
795            }
796        }
797
798        hlsl_ast.set_compiler_options(&spirv_cross_hlsl_options)?;
799
800        hlsl_ast.compile()?
801    };
802
803    Ok(CrossCompileOutputDx12 {
804        dx12_src,
805        reflection_data: compile_result.reflection_data,
806    })
807}
808
809pub struct CrossCompileOutputMetal {
810    metal_src: String,
811    reflection_data: Option<ShaderProcessorRefectionData>,
812}
813
814fn cross_compile_to_metal(
815    glsl_file: &Path,
816    compile_parameters: &CompileParameters,
817) -> Result<CrossCompileOutputMetal, Box<dyn Error>> {
818    log::trace!("{:?}: create msl", glsl_file);
819    let compile_result = compile_glsl(compile_parameters, PREPROCESSOR_DEF_PLATFORM_METAL)?;
820
821    let metal_src = if let Some(src) = try_load_override_src(glsl_file, ".metal")? {
822        src
823    } else {
824        let spirv_cross_module =
825            spirv_cross::spirv::Module::from_words(compile_result.unoptimized_spv.as_binary());
826
827        let mut msl_ast =
828            spirv_cross::spirv::Ast::<spirv_cross::msl::Target>::parse(&spirv_cross_module)?;
829        let mut spirv_cross_msl_options = spirv_cross::msl::CompilerOptions::default();
830        spirv_cross_msl_options.version = spirv_cross::msl::Version::V2_1;
831        spirv_cross_msl_options.enable_argument_buffers = true;
832        spirv_cross_msl_options.force_active_argument_buffer_resources = true;
833        //TODO: Add equivalent to --msl-no-clip-distance-user-varying
834
835        //TODO: Set this up
836        spirv_cross_msl_options.resource_binding_overrides = compile_result
837            .reflection_data
838            .as_ref()
839            .unwrap()
840            .msl_argument_buffer_assignments
841            .clone();
842        //println!(" binding overrides {:?}", spirv_cross_msl_options.resource_binding_overrides);
843        //spirv_cross_msl_options.vertex_attribute_overrides
844        spirv_cross_msl_options.const_samplers = compile_result
845            .reflection_data
846            .as_ref()
847            .unwrap()
848            .msl_const_samplers
849            .clone();
850
851        msl_ast.set_compiler_options(&spirv_cross_msl_options)?;
852        msl_ast.compile()?
853    };
854
855    Ok(CrossCompileOutputMetal {
856        metal_src,
857        reflection_data: compile_result.reflection_data,
858    })
859}
860
/// Output of `cross_compile_to_gles3`: GLSL ES 3.00 source together with the reflection data
/// produced by the GLSL compile step.
pub struct CrossCompileOutputGles3 {
    /// GLSL ES source, either loaded via `try_load_override_src` (".gles3" suffix) or produced
    /// by spirv_cross from the unoptimized SPIR-V.
    gles3_src: String,
    /// Reflection data from `compile_glsl`. When the source was cross-compiled (not
    /// overridden), GL sampler names have been recorded into it by `rename_gl_samplers`.
    reflection_data: Option<ShaderProcessorRefectionData>,
}
865
866fn cross_compile_to_gles3(
867    glsl_file: &Path,
868    compile_parameters: &CompileParameters,
869) -> Result<CrossCompileOutputGles3, Box<dyn Error>> {
870    log::trace!("{:?}: create gles3", glsl_file);
871    let mut compile_result = compile_glsl(compile_parameters, PREPROCESSOR_DEF_PLATFORM_GLES3)?;
872
873    let gles3_src = if let Some(src) = try_load_override_src(glsl_file, ".gles3")? {
874        src
875    } else {
876        let spirv_cross_module =
877            spirv_cross::spirv::Module::from_words(compile_result.unoptimized_spv.as_binary());
878
879        let mut gles3_ast =
880            spirv_cross::spirv::Ast::<spirv_cross::glsl::Target>::parse(&spirv_cross_module)?;
881        let mut spirv_cross_gles3_options = spirv_cross::glsl::CompilerOptions::default();
882        spirv_cross_gles3_options.version = spirv_cross::glsl::Version::V3_00Es;
883        spirv_cross_gles3_options.vulkan_semantics = false;
884        spirv_cross_gles3_options.vertex.transform_clip_space = true;
885        spirv_cross_gles3_options.vertex.invert_y = true;
886
887        let shader_resources = compile_result.ast.get_shader_resources()?;
888
889        rename_gl_samplers(&mut compile_result.reflection_data, &mut gles3_ast)?;
890        rename_gl_in_out_attributes(
891            compile_parameters.shader_kind,
892            &mut gles3_ast,
893            &shader_resources,
894        )?;
895
896        gles3_ast.set_compiler_options(&spirv_cross_gles3_options)?;
897        gles3_ast.compile()?
898    };
899
900    Ok(CrossCompileOutputGles3 {
901        gles3_src,
902        reflection_data: compile_result.reflection_data,
903    })
904}
905
/// Output of `cross_compile_to_gles2`: GLSL ES 1.00 source together with the reflection data
/// produced by the GLSL compile step.
pub struct CrossCompileOutputGles2 {
    /// GLSL ES source, either loaded via `try_load_override_src` (".gles2" suffix) or produced
    /// by spirv_cross from the unoptimized SPIR-V.
    gles2_src: String,
    /// Reflection data from `compile_glsl`. When the source was cross-compiled (not
    /// overridden), GL sampler names have been recorded into it by `rename_gl_samplers`.
    reflection_data: Option<ShaderProcessorRefectionData>,
}
910
911fn cross_compile_to_gles2(
912    glsl_file: &Path,
913    compile_parameters: &CompileParameters,
914) -> Result<CrossCompileOutputGles2, Box<dyn Error>> {
915    log::trace!("{:?}: create gles2", glsl_file);
916    let mut compile_result = compile_glsl(compile_parameters, PREPROCESSOR_DEF_PLATFORM_GLES2)?;
917
918    let gles2_src = if let Some(src) = try_load_override_src(glsl_file, ".gles2")? {
919        src
920    } else {
921        let spirv_cross_module =
922            spirv_cross::spirv::Module::from_words(compile_result.unoptimized_spv.as_binary());
923
924        let mut gles2_ast =
925            spirv_cross::spirv::Ast::<spirv_cross::glsl::Target>::parse(&spirv_cross_module)?;
926        let mut spirv_cross_gles2_options = spirv_cross::glsl::CompilerOptions::default();
927        spirv_cross_gles2_options.version = spirv_cross::glsl::Version::V1_00Es;
928        spirv_cross_gles2_options.vulkan_semantics = false;
929        spirv_cross_gles2_options.vertex.transform_clip_space = true;
930        spirv_cross_gles2_options.vertex.invert_y = true;
931
932        let shader_resources = compile_result.ast.get_shader_resources()?;
933
934        // Rename uniform blocks to be consistent with how they would appear in GL ES 3.0. This way
935        // we can consistently use the same GL name across both backends
936        for resource in &shader_resources.uniform_buffers {
937            let block_name = gles2_ast.get_name(resource.base_type_id)?;
938            gles2_ast.set_name(
939                resource.base_type_id,
940                &format!("{}_UniformBlock", block_name),
941            )?;
942            gles2_ast.set_name(resource.id, &block_name)?;
943        }
944
945        rename_gl_samplers(&mut compile_result.reflection_data, &mut gles2_ast)?;
946        rename_gl_in_out_attributes(
947            compile_parameters.shader_kind,
948            &mut gles2_ast,
949            &shader_resources,
950        )?;
951
952        gles2_ast.set_compiler_options(&spirv_cross_gles2_options)?;
953        gles2_ast.compile()?
954    };
955
956    Ok(CrossCompileOutputGles2 {
957        gles2_src,
958        reflection_data: compile_result.reflection_data,
959    })
960}
961
/// Writes `contents` to `path`, creating any missing parent directories first.
///
/// Paths that have no parent component (the empty path or a filesystem root) skip directory
/// creation. The write is then attempted directly, and `std::fs::write` reports the failure
/// instead of this function panicking.
///
/// # Errors
/// Returns the `io::Error` from directory creation or from the write itself.
fn write_output_file<C: AsRef<[u8]>>(
    path: &Path,
    contents: C,
) -> std::io::Result<()> {
    if let Some(parent_dir) = path.parent() {
        // `create_dir_all` treats an empty parent (a bare relative file name) as a no-op
        std::fs::create_dir_all(parent_dir)?;
    }
    std::fs::write(path, contents)
}
969
970fn rename_gl_samplers(
971    reflected_data: &mut Option<ShaderProcessorRefectionData>,
972    ast: &mut Ast<Target>,
973) -> Result<(), Box<dyn Error>> {
974    ast.build_combined_image_samplers()?;
975
976    let mut all_combined_textures = FnvHashSet::default();
977    for remap in ast.get_combined_image_samplers()? {
978        let texture_name = ast.get_name(remap.image_id)?;
979        let sampler_name = ast.get_name(remap.sampler_id)?;
980
981        let already_sampled = !all_combined_textures.insert(remap.image_id);
982        if already_sampled {
983            Err(format!("The texture {} is being read by multiple samplers. This is not supported in GL ES 2.0", texture_name))?;
984        }
985
986        if let Some(reflected_data) = reflected_data {
987            reflected_data.set_gl_sampler_name(&texture_name, &sampler_name);
988        }
989
990        ast.set_name(remap.combined_id, &texture_name)?
991    }
992
993    Ok(())
994}
995
996fn rename_gl_in_out_attributes(
997    shader_kind: ShaderKind,
998    ast: &mut Ast<Target>,
999    shader_resources: &ShaderResources,
1000) -> Result<(), Box<dyn Error>> {
1001    if normalize_shader_kind(shader_kind) == ShaderKind::Vertex {
1002        for resource in &shader_resources.stage_outputs {
1003            let location =
1004                ast.get_decoration(resource.id, spirv_cross::spirv::Decoration::Location)?;
1005            ast.rename_interface_variable(
1006                &[resource.clone()],
1007                location,
1008                &format!("interface_var_{}", location),
1009            )?;
1010        }
1011    } else if normalize_shader_kind(shader_kind) == ShaderKind::Fragment {
1012        for resource in &shader_resources.stage_inputs {
1013            let location =
1014                ast.get_decoration(resource.id, spirv_cross::spirv::Decoration::Location)?;
1015            ast.rename_interface_variable(
1016                &[resource.clone()],
1017                location,
1018                &format!("interface_var_{}", location),
1019            )?;
1020        }
1021    }
1022
1023    Ok(())
1024}
1025
1026fn shader_kind_from_args(args: &ShaderProcessorArgs) -> Option<shaderc::ShaderKind> {
1027    let extensions = [
1028        ("vert", shaderc::ShaderKind::Vertex),
1029        ("frag", shaderc::ShaderKind::Fragment),
1030        ("tesc", shaderc::ShaderKind::TessControl),
1031        ("tese", shaderc::ShaderKind::TessEvaluation),
1032        ("geom", shaderc::ShaderKind::Geometry),
1033        ("comp", shaderc::ShaderKind::Compute),
1034        //("spvasm", shaderc::ShaderKind::Vertex), // we don't parse spvasm
1035        ("rgen", shaderc::ShaderKind::RayGeneration),
1036        ("rahit", shaderc::ShaderKind::AnyHit),
1037        ("rchit", shaderc::ShaderKind::ClosestHit),
1038        ("rmiss", shaderc::ShaderKind::Miss),
1039        ("rint", shaderc::ShaderKind::Intersection),
1040        ("rcall", shaderc::ShaderKind::Callable),
1041        ("task", shaderc::ShaderKind::Task),
1042        ("mesh", shaderc::ShaderKind::Mesh),
1043    ];
1044
1045    if let Some(shader_kind) = &args.shader_kind {
1046        for &(extension, kind) in &extensions {
1047            if shader_kind == extension {
1048                return Some(kind);
1049            }
1050        }
1051    }
1052
1053    None
1054}
1055
1056// based on https://github.com/google/shaderc/blob/caa519ca532a6a3a0279509fce2ceb791c4f4651/glslc/src/shader_stage.cc#L69
1057fn deduce_default_shader_kind_from_path(path: &Path) -> Option<shaderc::ShaderKind> {
1058    let extensions = [
1059        ("vert", shaderc::ShaderKind::DefaultVertex),
1060        ("frag", shaderc::ShaderKind::DefaultFragment),
1061        ("tesc", shaderc::ShaderKind::DefaultTessControl),
1062        ("tese", shaderc::ShaderKind::DefaultTessEvaluation),
1063        ("geom", shaderc::ShaderKind::DefaultGeometry),
1064        ("comp", shaderc::ShaderKind::DefaultCompute),
1065        //("spvasm", shaderc::ShaderKind::Vertex), // we don't parse spvasm
1066        ("rgen", shaderc::ShaderKind::DefaultRayGeneration),
1067        ("rahit", shaderc::ShaderKind::DefaultAnyHit),
1068        ("rchit", shaderc::ShaderKind::DefaultClosestHit),
1069        ("rmiss", shaderc::ShaderKind::DefaultMiss),
1070        ("rint", shaderc::ShaderKind::DefaultIntersection),
1071        ("rcall", shaderc::ShaderKind::DefaultCallable),
1072        ("task", shaderc::ShaderKind::DefaultTask),
1073        ("mesh", shaderc::ShaderKind::DefaultMesh),
1074    ];
1075
1076    if let Some(extension) = path.extension() {
1077        let as_str = extension.to_string_lossy();
1078
1079        for &(extension, kind) in &extensions {
1080            if as_str.contains(extension) {
1081                return Some(kind);
1082            }
1083        }
1084    }
1085
1086    None
1087}
1088
1089fn normalize_shader_kind(shader_kind: ShaderKind) -> ShaderKind {
1090    match shader_kind {
1091        ShaderKind::Vertex | ShaderKind::DefaultVertex => ShaderKind::Vertex,
1092        ShaderKind::Fragment | ShaderKind::DefaultFragment => ShaderKind::Fragment,
1093        ShaderKind::Compute | ShaderKind::DefaultCompute => ShaderKind::Compute,
1094        ShaderKind::Geometry | ShaderKind::DefaultGeometry => ShaderKind::Geometry,
1095        ShaderKind::TessControl | ShaderKind::DefaultTessControl => ShaderKind::TessControl,
1096        ShaderKind::TessEvaluation | ShaderKind::DefaultTessEvaluation => {
1097            ShaderKind::TessEvaluation
1098        }
1099        ShaderKind::RayGeneration | ShaderKind::DefaultRayGeneration => ShaderKind::RayGeneration,
1100        ShaderKind::AnyHit | ShaderKind::DefaultAnyHit => ShaderKind::AnyHit,
1101        ShaderKind::ClosestHit | ShaderKind::DefaultClosestHit => ShaderKind::ClosestHit,
1102        ShaderKind::Miss | ShaderKind::DefaultMiss => ShaderKind::Miss,
1103        ShaderKind::Intersection | ShaderKind::DefaultIntersection => ShaderKind::Intersection,
1104        ShaderKind::Callable | ShaderKind::DefaultCallable => ShaderKind::Callable,
1105        ShaderKind::Task | ShaderKind::DefaultTask => ShaderKind::Task,
1106        ShaderKind::Mesh | ShaderKind::DefaultMesh => ShaderKind::Mesh,
1107        ShaderKind::InferFromSource => ShaderKind::InferFromSource,
1108        ShaderKind::SpirvAssembly => ShaderKind::SpirvAssembly,
1109    }
1110}