use crate::{entry_point, read_file_to_string, structs, RegisteredType, TypesMeta};
use ahash::HashMap;
use proc_macro2::TokenStream;
pub use shaderc::{CompilationArtifact, IncludeType, ResolvedInclude, ShaderKind};
use shaderc::{CompileOptions, Compiler, EnvVersion, SpirvVersion, TargetEnv};
use std::{
cell::{RefCell, RefMut},
io::Error as IoError,
iter::Iterator,
path::Path,
};
use vulkano::shader::{
reflect,
spirv::{Spirv, SpirvError},
};
/// Borrows `path` as a `&str`.
///
/// # Panics
///
/// Panics when the path is not valid Unicode, since the callers need a textual path.
pub(super) fn path_to_str(path: &Path) -> &str {
    path.to_str().unwrap_or_else(|| {
        panic!(
            "{}",
            "Could not stringify the file to be included. Make sure the path consists of \
             valid unicode characters."
        )
    })
}
/// Resolves one `#include` directive on behalf of shaderc.
///
/// * `requested_source_path_raw` - the path written in the directive.
/// * `directive_type` - `Relative` for `#include "..."`, `Standard` for `#include <...>`.
/// * `contained_within_path_raw` - name of the file containing the directive. At
///   `recursion_depth == 1` it is joined onto `base_path`. At deeper levels it is used as-is
///   (presumably the `resolved_name` this function returned for the including file).
/// * `include_directories` - directories searched, in order, for `Standard` includes.
/// * `root_source_has_path` - `false` when the root shader is embedded GLSL without a file.
/// * `includes_tracker` - the resolved path of every successfully included file is appended.
///
/// Every failure is reported as `Err(String)` instead of a panic. This function is invoked
/// from the shaderc compiler's include callback, where unwinding is undesirable.
#[allow(clippy::too_many_arguments)]
fn include_callback(
    requested_source_path_raw: &str,
    directive_type: IncludeType,
    contained_within_path_raw: &str,
    recursion_depth: usize,
    include_directories: &[impl AsRef<Path>],
    root_source_has_path: bool,
    base_path: &impl AsRef<Path>,
    mut includes_tracker: RefMut<'_, Vec<String>>,
) -> Result<ResolvedInclude, String> {
    let requested_source_path = Path::new(requested_source_path_raw);

    let file_to_include = match directive_type {
        IncludeType::Relative => {
            // Embedded GLSL has no directory of its own, so a top-level relative
            // `#include "..."` cannot be resolved; only absolute paths are accepted there.
            if !root_source_has_path
                && recursion_depth == 1
                && !requested_source_path.is_absolute()
            {
                return Err(
                    match (requested_source_path.file_name(), requested_source_path.parent()) {
                        (Some(name), Some(directory)) => format!(
                            "Usage of relative paths in imports in embedded GLSL is not \
                             allowed, try using `#include <{}>` and adding the directory \
                             `{}` to the `include` array in your `shader!` macro call \
                             instead.",
                            name.to_string_lossy(),
                            directory.to_string_lossy()
                        ),
                        // E.g. `#include ".."` or an empty path: there is no file name to
                        // suggest, so report the raw request instead of panicking.
                        _ => format!(
                            "Usage of relative paths in imports in embedded GLSL is not \
                             allowed. Requested path: {}",
                            requested_source_path_raw
                        ),
                    },
                );
            }

            let containing_directory = Path::new(contained_within_path_raw)
                .parent()
                .ok_or_else(|| {
                    format!(
                        "The file `{}` does not reside in a directory. This is \
                         an implementation error.",
                        contained_within_path_raw
                    )
                })?;

            // The root source name is relative to `base_path`; nested includes already
            // carry the path this function resolved for them.
            let mut resolved_path = if recursion_depth == 1 {
                base_path.as_ref().join(containing_directory)
            } else {
                containing_directory.to_owned()
            };
            // `PathBuf::push` replaces the whole path when the request is absolute.
            resolved_path.push(requested_source_path);

            if !resolved_path.is_file() {
                return Err(format!(
                    "Invalid inclusion path `{}`, the path does not point to a file.",
                    requested_source_path_raw
                ));
            }
            resolved_path
        }
        IncludeType::Standard => {
            if requested_source_path.is_absolute() {
                return Err(format!(
                    "No such file found, as specified by the absolute path. \
                     Keep in mind, that absolute paths cannot be used with \
                     inclusion from standard directories (`#include <...>`), try \
                     using `#include \"...\"` instead. Requested path: {}",
                    requested_source_path_raw
                ));
            }

            // The first include directory that contains the requested file wins.
            include_directories
                .iter()
                .map(|include_directory| include_directory.as_ref().join(requested_source_path))
                .find(|candidate| candidate.is_file())
                .ok_or_else(|| {
                    format!(
                        "Could not include the file `{}` from any include directories.",
                        requested_source_path_raw
                    )
                })?
        }
    };

    let file_to_include_string = file_to_include
        .to_str()
        .ok_or_else(|| {
            format!(
                "Could not stringify the file `{}` to be included. Make sure the path \
                 consists of valid unicode characters.",
                file_to_include.display()
            )
        })?
        .to_owned();

    let content = read_file_to_string(file_to_include.as_path()).map_err(|_| {
        format!(
            "Could not read the contents of file `{}` to be included in the \
             shader source.",
            &file_to_include_string
        )
    })?;

    includes_tracker.push(file_to_include_string.clone());
    Ok(ResolvedInclude {
        resolved_name: file_to_include_string,
        content,
    })
}
/// Compiles GLSL `code` into SPIR-V using shaderc.
///
/// * `path` - name of the root source file, relative to `base_path`. When `None`, the code is
///   treated as embedded GLSL and reported to shaderc as `shader.glsl`. In that case,
///   top-level relative `#include "..."` directives are rejected by the include callback.
/// * `ty` - the shader stage to compile for.
/// * `include_directories` - search directories for `#include <...>`.
/// * `macro_defines` - `(name, value)` pairs passed to the preprocessor.
/// * `vulkan_version` - target Vulkan environment. Defaults to Vulkan 1.0.
/// * `spirv_version` - target SPIR-V version. If `None`, shaderc's default is used.
///
/// Returns the compilation artifact together with the resolved paths of every included
/// file, in the order the include callback resolved them.
///
/// # Errors
///
/// Returns a message if the compiler or its options cannot be created, or if compilation
/// (including include resolution) fails.
#[allow(clippy::too_many_arguments)]
pub fn compile(
    path: Option<String>,
    base_path: &impl AsRef<Path>,
    code: &str,
    ty: ShaderKind,
    include_directories: &[impl AsRef<Path>],
    macro_defines: &[(impl AsRef<str>, impl AsRef<str>)],
    vulkan_version: Option<EnvVersion>,
    spirv_version: Option<SpirvVersion>,
) -> Result<(CompilationArtifact, Vec<String>), String> {
    let includes_tracker = RefCell::new(Vec::new());
    let compiler = Compiler::new().ok_or("failed to create GLSL compiler")?;
    let mut compile_options = CompileOptions::new().ok_or("failed to initialize compile option")?;

    compile_options.set_target_env(
        TargetEnv::Vulkan,
        vulkan_version.unwrap_or(EnvVersion::Vulkan1_0) as u32,
    );
    if let Some(spirv_version) = spirv_version {
        compile_options.set_target_spirv(spirv_version);
    }

    // Embedded GLSL has no file of its own; shaderc still needs a name for the root source.
    let root_source_path = path.as_deref().unwrap_or("shader.glsl");

    compile_options.set_include_callback(
        |requested_source_path, directive_type, contained_within_path, recursion_depth| {
            include_callback(
                requested_source_path,
                directive_type,
                contained_within_path,
                recursion_depth,
                include_directories,
                path.is_some(),
                base_path,
                includes_tracker.borrow_mut(),
            )
        },
    );

    for (macro_name, macro_value) in macro_defines {
        compile_options.add_macro_definition(macro_name.as_ref(), Some(macro_value.as_ref()));
    }

    #[cfg(feature = "shaderc-debug")]
    compile_options.set_generate_debug_info();

    let content = compiler
        .compile_into_spirv(code, ty, root_source_path, "main", Some(&compile_options))
        .map_err(|e| e.to_string())?;

    // The include callback stored in `compile_options` borrows `includes_tracker`; release it
    // so the collected list can be moved out instead of cloned.
    drop(compile_options);
    Ok((content, includes_tracker.into_inner()))
}
/// Generates Rust code for a compiled shader from its SPIR-V `words`.
///
/// Returns `(shader_code, structs)`:
/// * `shader_code` defines a `load` function (named `load_<prefix>` when `prefix` is
///   non-empty). The function builds a `ShaderModule` from the embedded words together with
///   the reflected SPIR-V version, capabilities, extensions and entry points. The
///   specialization-constant definitions follow it.
/// * `structs` holds the generated Rust definitions of the shader's structs.
///
/// Each entry of `input_paths` is referenced through `include_bytes!` inside the generated
/// function and bound to the unused `_bytes` local. This is presumably done so that the build
/// tracks those files (TODO confirm).
///
/// # Errors
///
/// Returns [`Error::SpirvError`] when `words` cannot be parsed as SPIR-V.
pub(super) fn reflect<'a>(
    prefix: &'a str,
    words: &[u32],
    types_meta: &TypesMeta,
    input_paths: impl IntoIterator<Item = &'a str>,
    shared_constants: bool,
    types_registry: &'a mut HashMap<String, RegisteredType>,
) -> Result<(TokenStream, TokenStream), Error> {
    let spirv = Spirv::new(words)?;

    // One `include_bytes!` expression per input file, expanded lazily inside `quote!`.
    let include_bytes = input_paths
        .into_iter()
        .map(|input_path| quote! { ::std::include_bytes!( #input_path ) });

    let version = spirv.version();
    let (major, minor, patch) = (version.major, version.minor, version.patch);
    let spirv_version = quote! {
        ::vulkano::Version {
            major: #major,
            minor: #minor,
            patch: #patch,
        }
    };

    let spirv_capabilities = reflect::spirv_capabilities(&spirv).map(|capability| {
        let variant = format_ident!("{}", format!("{:?}", capability));
        quote! { &::vulkano::shader::spirv::Capability::#variant }
    });
    let spirv_extensions = reflect::spirv_extensions(&spirv);
    let entry_points = reflect::entry_points(&spirv)
        .map(|(name, model, info)| entry_point::write_entry_point(&name, model, &info));

    // Both this call and `write_structs` below receive `types_registry`; keep this one first.
    let specialization_constants =
        structs::write_specialization_constants(prefix, &spirv, shared_constants, types_registry);

    let load_name = match prefix {
        "" => format_ident!("load"),
        _ => format_ident!("load_{}", prefix),
    };

    let shader_code = quote! {
        #[inline]
        #[allow(unsafe_code)]
        pub fn #load_name(device: ::std::sync::Arc<::vulkano::device::Device>)
            -> Result<::std::sync::Arc<::vulkano::shader::ShaderModule>, ::vulkano::shader::ShaderCreationError>
        {
            let _bytes = ( #( #include_bytes),* );
            static WORDS: &[u32] = &[ #( #words ),* ];
            unsafe {
                ::vulkano::shader::ShaderModule::from_words_with_data(
                    device,
                    WORDS,
                    #spirv_version,
                    [#(#spirv_capabilities),*],
                    [#(#spirv_extensions),*],
                    [#(#entry_points),*],
                )
            }
        }
        #specialization_constants
    };

    let struct_definitions = structs::write_structs(prefix, &spirv, types_meta, types_registry);
    Ok((shader_code, struct_definitions))
}
/// Errors returned by [`reflect`].
#[derive(Debug)]
pub enum Error {
    /// An I/O error.
    IoError(IoError),
    /// The SPIR-V words could not be parsed (produced by `Spirv::new`).
    SpirvError(SpirvError),
}

impl From<IoError> for Error {
    fn from(err: IoError) -> Error {
        Error::IoError(err)
    }
}

impl From<SpirvError> for Error {
    fn from(err: SpirvError) -> Error {
        Error::SpirvError(err)
    }
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::IoError(err) => write!(f, "I/O error: {}", err),
            Error::SpirvError(err) => write!(f, "error while parsing SPIR-V: {}", err),
        }
    }
}

impl std::error::Error for Error {
    /// Exposes the wrapped error so callers can walk the error chain.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Error::IoError(err) => Some(err),
            Error::SpirvError(err) => Some(err),
        }
    }
}
// Unit tests for GLSL compilation (`compile`), `#include` resolution, macro definitions,
// struct generation and SPIR-V descriptor reflection.
#[cfg(test)]
mod tests {
use super::*;
use crate::codegen::compile;
use shaderc::ShaderKind;
use std::path::{Path, PathBuf};
use vulkano::shader::{reflect, spirv::Spirv};
/// Path separator used to build the expected include paths on non-Windows targets.
#[cfg(not(target_os = "windows"))]
pub fn path_separator() -> &'static str {
"/"
}
/// Path separator used to build the expected include paths on Windows.
#[cfg(target_os = "windows")]
pub fn path_separator() -> &'static str {
"\\"
}
/// Joins each of `paths` onto `root_path` and stringifies the result. This produces the
/// same form that `compile` reports in its list of included files.
fn convert_paths(root_path: &Path, paths: &[String]) -> Vec<String> {
paths
.iter()
.map(|p| path_to_str(root_path.join(p).as_path()).to_owned())
.collect()
}
/// Parses a prebuilt SPIR-V binary. The bytes are reassembled into `u32` words in
/// little-endian order (byte 0 is the least significant).
#[test]
fn spirv_parse() {
let data = include_bytes!("../tests/frag.spv");
let insts: Vec<_> = data
.chunks(4)
.map(|c| {
((c[3] as u32) << 24) | ((c[2] as u32) << 16) | ((c[1] as u32) << 8) | c[0] as u32
})
.collect();
Spirv::new(&insts).unwrap();
}
/// A `vec3` array inside a uniform block: `write_structs` is expected to panic on it
/// (presumably because of the array's layout/stride — NOTE(review): confirm the reason).
#[test]
fn test_bad_alignment() {
let includes: [PathBuf; 0] = [];
let defines: [(String, String); 0] = [];
let (comp, _) = compile(
None,
&Path::new(""),
"
#version 450
struct MyStruct {
vec3 vs[2];
};
layout(binding=0) uniform UBO {
MyStruct s;
};
void main() {}
",
ShaderKind::Vertex,
&includes,
&defines,
None,
None,
)
.unwrap();
let spirv = Spirv::new(comp.as_binary()).unwrap();
let res = std::panic::catch_unwind(|| {
structs::write_structs("", &spirv, &TypesMeta::default(), &mut HashMap::default())
});
assert!(res.is_err());
}
/// A `vec4` array inside a uniform block: `write_structs` must succeed.
#[test]
fn test_trivial_alignment() {
let includes: [PathBuf; 0] = [];
let defines: [(String, String); 0] = [];
let (comp, _) = compile(
None,
&Path::new(""),
"
#version 450
struct MyStruct {
vec4 vs[2];
};
layout(binding=0) uniform UBO {
MyStruct s;
};
void main() {}
",
ShaderKind::Vertex,
&includes,
&defines,
None,
None,
)
.unwrap();
let spirv = Spirv::new(comp.as_binary()).unwrap();
structs::write_structs("", &spirv, &TypesMeta::default(), &mut HashMap::default());
}
/// A `vec3` wrapped in a struct and used as an array element: unlike the bare `vec3`
/// array above, `write_structs` must succeed.
#[test]
fn test_wrap_alignment() {
let includes: [PathBuf; 0] = [];
let defines: [(String, String); 0] = [];
let (comp, _) = compile(
None,
&Path::new(""),
"
#version 450
struct Vec3Wrap {
vec3 v;
};
struct MyStruct {
Vec3Wrap vs[2];
};
layout(binding=0) uniform UBO {
MyStruct s;
};
void main() {}
",
ShaderKind::Vertex,
&includes,
&defines,
None,
None,
)
.unwrap();
let spirv = Spirv::new(comp.as_binary()).unwrap();
structs::write_structs("", &spirv, &TypesMeta::default(), &mut HashMap::default());
}
/// Exercises `#include` resolution. The root source is named `tests/include_test.glsl`
/// relative to the crate root, and the test checks the list of resolved includes that
/// `compile` returns.
#[test]
fn test_include_resolution() {
let root_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
let empty_includes: [PathBuf; 0] = [];
let defines: [(String, String); 0] = [];
// Relative includes resolve against the directory of the root source (`tests/`).
let (_compile_relative, _) = compile(
Some(String::from("tests/include_test.glsl")),
&root_path,
"
#version 450
#include \"include_dir_a/target_a.glsl\"
#include \"include_dir_b/target_b.glsl\"
void main() {}
",
ShaderKind::Vertex,
&empty_includes,
&defines,
None,
None,
)
.expect("Cannot resolve include files");
// Standard includes are found by searching the include directories in order.
let (_compile_include_paths, includes) = compile(
Some(String::from("tests/include_test.glsl")),
&root_path,
"
#version 450
#include <target_a.glsl>
#include <target_b.glsl>
void main() {}
",
ShaderKind::Vertex,
&[
root_path.join("tests").join("include_dir_a"),
root_path.join("tests").join("include_dir_b"),
],
&defines,
None,
None,
)
.expect("Cannot resolve include files");
assert_eq!(
includes,
convert_paths(
&root_path,
&[
vec!["tests", "include_dir_a", "target_a.glsl"].join(path_separator()),
vec!["tests", "include_dir_b", "target_b.glsl"].join(path_separator()),
]
)
);
// A standard include may contain `..`. The reported path is the joined path, and
// `..` is not normalized away.
let (_compile_include_paths_with_relative, includes_with_relative) = compile(
Some(String::from("tests/include_test.glsl")),
&root_path,
"
#version 450
#include <target_a.glsl>
#include <../include_dir_b/target_b.glsl>
void main() {}
",
ShaderKind::Vertex,
&[root_path.join("tests").join("include_dir_a")],
&defines,
None,
None,
)
.expect("Cannot resolve include files");
assert_eq!(
includes_with_relative,
convert_paths(
&root_path,
&[
vec!["tests", "include_dir_a", "target_a.glsl"].join(path_separator()),
vec!["tests", "include_dir_a", "../include_dir_b/target_b.glsl"]
.join(path_separator()),
]
)
);
// An absolute path in `#include "..."` is used as-is.
let absolute_path = root_path
.join("tests")
.join("include_dir_a")
.join("target_a.glsl");
let absolute_path_str = absolute_path
.to_str()
.expect("Cannot run tests in a folder with non unicode characters");
let (_compile_absolute_path, includes_absolute_path) = compile(
Some(String::from("tests/include_test.glsl")),
&root_path,
&format!(
"
#version 450
#include \"{}\"
void main() {{}}
",
absolute_path_str
),
ShaderKind::Vertex,
&empty_includes,
&defines,
None,
None,
)
.expect("Cannot resolve include files");
assert_eq!(
includes_absolute_path,
convert_paths(
&root_path,
&[vec!["tests", "include_dir_a", "target_a.glsl"].join(path_separator())]
)
);
// Nested includes. Judging by the expected list, `target_c.glsl` itself includes
// `../include_dir_a/target_a.glsl` (relative to its own directory) and `target_b.glsl`.
let (_compile_recursive_, includes_recursive) = compile(
Some(String::from("tests/include_test.glsl")),
&root_path,
"
#version 450
#include <target_c.glsl>
void main() {}
",
ShaderKind::Vertex,
&[
root_path.join("tests").join("include_dir_b"),
root_path.join("tests").join("include_dir_c"),
],
&defines,
None,
None,
)
.expect("Cannot resolve include files");
assert_eq!(
includes_recursive,
convert_paths(
&root_path,
&[
vec!["tests", "include_dir_c", "target_c.glsl"].join(path_separator()),
vec!["tests", "include_dir_c", "../include_dir_a/target_a.glsl"]
.join(path_separator()),
vec!["tests", "include_dir_b", "target_b.glsl"].join(path_separator()),
]
)
);
}
/// `main` is only defined when `NAME1` is defined and `NAME2 > 29`. Compilation must
/// therefore fail without macro definitions and succeed with them.
#[test]
fn test_macros() {
let empty_includes: [PathBuf; 0] = [];
let defines = vec![("NAME1", ""), ("NAME2", "58")];
let no_defines: [(String, String); 0] = [];
let need_defines = "
#version 450
#if defined(NAME1) && NAME2 > 29
void main() {}
#endif
";
let compile_no_defines = compile(
None,
&Path::new(""),
need_defines,
ShaderKind::Vertex,
&empty_includes,
&no_defines,
None,
None,
);
assert!(compile_no_defines.is_err());
let compile_defines = compile(
None,
&Path::new(""),
need_defines,
ShaderKind::Vertex,
&empty_includes,
&defines,
None,
None,
);
compile_defines.expect("Setting shader macros did not work");
}
/// Descriptor requirements are reported per entry point. The first entry point uses five
/// bindings and the second uses three, all in set 0.
#[test]
fn test_descriptor_calculation_with_multiple_entrypoints() {
let data = include_bytes!("../tests/multiple_entrypoints.spv");
let instructions: Vec<u32> = data
.chunks(4)
.map(|c| {
((c[3] as u32) << 24) | ((c[2] as u32) << 16) | ((c[1] as u32) << 8) | c[0] as u32
})
.collect();
let spirv = Spirv::new(&instructions).unwrap();
let mut descriptors = Vec::new();
for (_, _, info) in reflect::entry_points(&spirv) {
descriptors.push(info.descriptor_requirements);
}
let e1_descriptors = descriptors.get(0).expect("Could not find entrypoint1");
let mut e1_bindings = Vec::new();
for loc in e1_descriptors.keys() {
e1_bindings.push(*loc);
}
assert_eq!(e1_bindings.len(), 5);
assert!(e1_bindings.contains(&(0, 0)));
assert!(e1_bindings.contains(&(0, 1)));
assert!(e1_bindings.contains(&(0, 2)));
assert!(e1_bindings.contains(&(0, 3)));
assert!(e1_bindings.contains(&(0, 4)));
let e2_descriptors = descriptors.get(1).expect("Could not find entrypoint2");
let mut e2_bindings = Vec::new();
for loc in e2_descriptors.keys() {
e2_bindings.push(*loc);
}
assert_eq!(e2_bindings.len(), 3);
assert!(e2_bindings.contains(&(0, 0)));
assert!(e2_bindings.contains(&(0, 1)));
assert!(e2_bindings.contains(&(0, 2)));
}
/// Descriptors that are only accessed inside helper functions called from `main` must still
/// appear in the entry point's descriptor requirements, keyed as `(set, binding)`.
#[test]
fn test_descriptor_calculation_with_multiple_functions() {
let includes: [PathBuf; 0] = [];
let defines: [(String, String); 0] = [];
let (comp, _) = compile(
None,
&Path::new(""),
"
#version 450
layout(set = 1, binding = 0) buffer Buffer {
vec3 data;
} bo;
layout(set = 2, binding = 0) uniform Uniform {
float data;
} ubo;
layout(set = 3, binding = 1) uniform sampler textureSampler;
layout(set = 3, binding = 2) uniform texture2D imageTexture;
float withMagicSparkles(float data) {
return texture(sampler2D(imageTexture, textureSampler), vec2(data, data)).x;
}
vec3 makeSecretSauce() {
return vec3(withMagicSparkles(ubo.data));
}
void main() {
bo.data = makeSecretSauce();
}
",
ShaderKind::Vertex,
&includes,
&defines,
None,
None,
)
.unwrap();
let spirv = Spirv::new(comp.as_binary()).unwrap();
if let Some((_, _, info)) = reflect::entry_points(&spirv).next() {
let mut bindings = Vec::new();
for (loc, _reqs) in info.descriptor_requirements {
bindings.push(loc);
}
assert_eq!(bindings.len(), 4);
assert!(bindings.contains(&(1, 0)));
assert!(bindings.contains(&(2, 0)));
assert!(bindings.contains(&(3, 1)));
assert!(bindings.contains(&(3, 2)));
return;
}
panic!("Could not find entrypoint");
}
}