use rustc_hash::FxHashMap as HashMap;
use std::path::{Path, PathBuf};
/// Banner written at the top of every generated shader file to discourage
/// hand edits; see `copy_glsl` and `prepend_header`.
const GENERATED_HEADER_PREFIX: &str = "//! DO NOT EDIT: This shader file was generated by Ply!";
/// Custom per-extension compiler: receives (source file path, output dir)
/// and returns glob patterns for extra files the generated output depends
/// on (used for incremental-rebuild detection).
pub type FileTypeHandler = Box<dyn Fn(&Path, &Path) -> Vec<String>>;
/// Build-script helper that compiles/copies shaders from `source_dir` into
/// `output_dir`, rebuilding only files whose content hash (or dependency
/// hashes) changed since the previous run.
pub struct ShaderBuild {
    // Root directory scanned recursively for shader sources.
    source_dir: PathBuf,
    // Where generated GLSL is written (consumed at runtime as an asset).
    output_dir: PathBuf,
    // Intermediate SPIR-V output produced by slangc.
    spirv_dir: PathBuf,
    // JSON file persisting content hashes between build-script runs.
    hash_file: PathBuf,
    // Explicit path to the slangc binary; `None` means resolve via PATH.
    slangc_path: Option<PathBuf>,
    // Dotted extension (".ext") -> user-supplied compile handler.
    custom_handlers: HashMap<String, FileTypeHandler>,
}
impl ShaderBuild {
    /// Creates a builder with the default directory layout:
    /// sources in `shaders/`, output in `assets/build/shaders/`,
    /// intermediates and the hash cache under `build/shaders/`.
    pub fn new() -> Self {
        Self {
            source_dir: PathBuf::from("shaders/"),
            output_dir: PathBuf::from("assets/build/shaders/"),
            spirv_dir: PathBuf::from("build/shaders/spirv/"),
            hash_file: PathBuf::from("build/shaders/hashes.json"),
            slangc_path: None,
            custom_handlers: HashMap::default(),
        }
    }

    /// Overrides the directory scanned (recursively) for shader sources.
    pub fn source_dir(mut self, dir: &str) -> Self {
        self.source_dir = PathBuf::from(dir);
        self
    }

    /// Overrides the directory generated shaders are written to.
    pub fn output_dir(mut self, dir: &str) -> Self {
        self.output_dir = PathBuf::from(dir);
        self
    }

    /// Uses an explicit `slangc` binary instead of resolving it from `PATH`.
    pub fn slangc_path(mut self, path: &str) -> Self {
        self.slangc_path = Some(PathBuf::from(path));
        self
    }

    /// Registers a custom compiler for files with the given extension
    /// (accepted with or without the leading dot). The handler receives the
    /// source path and the output directory, and returns glob patterns
    /// describing extra files its output depends on; those files are hashed
    /// so a change to any of them triggers a rebuild of the shader.
    pub fn override_file_type_handler(
        mut self,
        extension: &str,
        handler: impl Fn(&Path, &Path) -> Vec<String> + 'static,
    ) -> Self {
        // Normalize to ".ext" so lookups match the dotted form used in `build`.
        let ext = if extension.starts_with('.') {
            extension.to_string()
        } else {
            format!(".{}", extension)
        };
        self.custom_handlers.insert(ext, Box::new(handler));
        self
    }

    /// Runs the shader build: scans `source_dir`, recompiles every shader
    /// whose content hash (or the hash of a declared dependency) changed
    /// since the last run, and records the new hashes for the next run.
    ///
    /// Intended to be called from `build.rs`. Per-shader compile failures
    /// are reported as `cargo:warning`s and skipped; the failed shader's
    /// hash is not recorded, so it is retried on the next run.
    pub fn build(self) {
        // Ask Cargo to re-run this script when shader sources change.
        // NOTE(review): dependencies reported by custom handlers that live
        // outside `source_dir` will not trigger a re-run by themselves.
        println!("cargo:rerun-if-changed={}", self.source_dir.display());
        std::fs::create_dir_all(&self.output_dir)
            .unwrap_or_else(|e| panic!("Failed to create output dir '{}': {}", self.output_dir.display(), e));
        std::fs::create_dir_all(&self.spirv_dir)
            .unwrap_or_else(|e| panic!("Failed to create SPIR-V dir '{}': {}", self.spirv_dir.display(), e));
        if let Some(parent) = self.hash_file.parent() {
            std::fs::create_dir_all(parent)
                .unwrap_or_else(|e| panic!("Failed to create hash dir '{}': {}", parent.display(), e));
        }
        // Hashes from the previous run; used to skip up-to-date shaders.
        let hashes = load_hashes(&self.hash_file);
        let mut new_hashes = HashMap::default();
        if !self.source_dir.exists() {
            println!(
                "cargo:warning=Shader source directory '{}' does not exist, skipping shader build",
                self.source_dir.display()
            );
            return;
        }
        let shader_files = collect_shader_files(&self.source_dir);
        if shader_files.is_empty() {
            println!(
                "cargo:warning=No shader files found in '{}'",
                self.source_dir.display()
            );
            return;
        }
        for file_path in &shader_files {
            // Dotted extension (".slang", ".glsl", ...) keyed against `custom_handlers`.
            let ext = file_path
                .extension()
                .and_then(|e| e.to_str())
                .map(|e| format!(".{}", e))
                .unwrap_or_default();
            let rel_path = file_path
                .strip_prefix(&self.source_dir)
                .unwrap_or(file_path);
            let source_hash = content_hash(file_path);
            // Cache key under which this shader's dependency globs are stored.
            let dep_key = format!("{}:deps", rel_path.display());
            // Only custom handlers report dependency globs, so only their
            // shaders can be invalidated by a dependency change.
            let dep_hashes_changed = if self.custom_handlers.contains_key(&ext) {
                let dep_globs = handler_dep_globs_cached(&hashes, &dep_key);
                check_dep_hashes_changed(&dep_globs, &hashes)
            } else {
                false
            };
            let hash_key = rel_path.display().to_string();
            let needs_rebuild = match hashes.get(&hash_key) {
                Some(old_hash) => *old_hash != source_hash || dep_hashes_changed,
                None => true,
            };
            if !needs_rebuild {
                // Up to date: carry the source hash, the cached dep globs and
                // every dep-file hash forward so the next run still sees them.
                new_hashes.insert(hash_key, source_hash);
                if let Some(deps) = hashes.get(&dep_key) {
                    new_hashes.insert(dep_key.clone(), deps.clone());
                }
                let dep_globs = handler_dep_globs_cached(&hashes, &dep_key);
                for pattern in &dep_globs {
                    if let Ok(paths) = expand_glob(pattern) {
                        for path in paths {
                            let key = path.display().to_string();
                            if let Some(h) = hashes.get(&key) {
                                new_hashes.insert(key, h.clone());
                            }
                        }
                    }
                }
                continue;
            }
            println!("cargo:warning=Compiling shader: {}", rel_path.display());
            // Compile (or copy) the shader. Custom handlers may report
            // dependency globs; the built-in paths never do.
            let dep_globs = if let Some(handler) = self.custom_handlers.get(&ext) {
                handler(file_path, &self.output_dir)
            } else {
                match ext.as_str() {
                    ".slang" | ".hlsl" => {
                        // On failure, skip without recording the hash so the
                        // next run retries this shader.
                        if !compile_slang(file_path, rel_path, &self.output_dir, &self.spirv_dir, self.slangc_path.as_deref()) {
                            continue;
                        }
                        vec![]
                    }
                    ".glsl" | ".frag" => {
                        copy_glsl(file_path, rel_path, &self.output_dir);
                        vec![]
                    }
                    other => {
                        println!(
                            "cargo:warning=Unknown shader extension '{}' for file '{}', skipping",
                            other,
                            rel_path.display()
                        );
                        continue;
                    }
                }
            };
            new_hashes.insert(hash_key, source_hash);
            if !dep_globs.is_empty() {
                // Persist the glob list as a JSON string array under the
                // ":deps" key, plus a content hash per matched file.
                let dep_json = format!(
                    "[{}]",
                    dep_globs
                        .iter()
                        .map(|g| format!("\"{}\"", g.replace('\\', "\\\\")))
                        .collect::<Vec<_>>()
                        .join(",")
                );
                new_hashes.insert(dep_key, dep_json);
                for pattern in &dep_globs {
                    if let Ok(paths) = expand_glob(pattern) {
                        for path in paths {
                            if path.exists() {
                                let dep_hash = content_hash(&path);
                                new_hashes.insert(path.display().to_string(), dep_hash);
                            }
                        }
                    }
                }
            }
        }
        save_hashes(&self.hash_file, &new_hashes);
    }
}
impl Default for ShaderBuild {
fn default() -> Self {
Self::new()
}
}
/// Loads the persisted hash cache; a missing or unreadable file simply
/// yields an empty map (forcing a full rebuild), never an error.
fn load_hashes(path: &Path) -> HashMap<String, String> {
    match std::fs::read_to_string(path) {
        Ok(content) => parse_simple_json_map(&content),
        Err(_) => HashMap::default(),
    }
}
/// Writes the hash cache as deterministic (key-sorted) pretty-printed JSON.
/// Panics on I/O failure — this runs inside a build script where failing
/// loudly is the correct behavior.
fn save_hashes(path: &Path, hashes: &HashMap<String, String>) {
    let mut entries: Vec<_> = hashes.iter().collect();
    // Sort for stable output; `sort_by` compares borrowed keys directly
    // (the previous `sort_by_key` cloned a String per comparison).
    entries.sort_by(|a, b| a.0.cmp(b.0));
    let mut json = String::from("{\n");
    for (i, (key, value)) in entries.iter().enumerate() {
        json.push_str(&format!(
            "  \"{}\": \"{}\"",
            escape_json(key),
            escape_json(value)
        ));
        // JSON forbids a trailing comma after the last entry.
        if i < entries.len() - 1 {
            json.push(',');
        }
        json.push('\n');
    }
    json.push('}');
    std::fs::write(path, json)
        .unwrap_or_else(|e| panic!("Failed to write hashes to '{}': {}", path.display(), e));
}
/// Parses a flat JSON object of string keys to string values, as written by
/// [`save_hashes`]. Anything that is not a `{ "k": "v", ... }` object of
/// strings yields an empty map; nested structures are not supported.
///
/// Backslash escape sequences inside strings are unescaped (the backslash
/// is dropped, the escaped character kept) — the inverse of [`escape_json`]
/// for the `\\` and `\"` sequences it produces.
fn parse_simple_json_map(json: &str) -> HashMap<String, String> {
    let mut map = HashMap::default();
    let trimmed = json.trim();
    if !trimmed.starts_with('{') || !trimmed.ends_with('}') {
        return map;
    }
    let inner = &trimmed[1..trimmed.len() - 1];
    let mut key = String::new();
    let mut value = String::new();
    let mut in_key = false;
    let mut in_value = false;
    let mut in_string = false;
    let mut escape_next = false;
    let mut after_colon = false;
    for ch in inner.chars() {
        if escape_next {
            // Previous char was the escape marker: keep this char verbatim.
            if in_key {
                key.push(ch);
            } else if in_value {
                value.push(ch);
            }
            escape_next = false;
            continue;
        }
        if ch == '\\' && in_string {
            // BUG FIX: do not push the backslash itself — it is only the
            // escape marker. Keeping it broke round-tripping of values
            // escaped by `escape_json` (e.g. the cached dep-glob JSON,
            // which then re-parsed as an empty list, and Windows paths,
            // whose keys never matched on reload).
            escape_next = true;
            continue;
        }
        if ch == '"' {
            if !in_string {
                // Opening quote: before the ':' it starts a key, after it a value.
                in_string = true;
                if !after_colon {
                    in_key = true;
                    in_value = false;
                } else {
                    in_value = true;
                    in_key = false;
                }
            } else {
                // Closing quote: a finished value completes one map entry.
                in_string = false;
                if in_value {
                    map.insert(key.clone(), value.clone());
                    key.clear();
                    value.clear();
                    in_key = false;
                    in_value = false;
                    after_colon = false;
                }
                if in_key {
                    in_key = false;
                }
            }
            continue;
        }
        if ch == ':' && !in_string {
            after_colon = true;
            continue;
        }
        if ch == ',' && !in_string {
            after_colon = false;
            continue;
        }
        if in_key {
            key.push(ch);
        } else if in_value {
            value.push(ch);
        }
    }
    map
}
/// Escapes backslashes and double quotes for embedding in a JSON string
/// literal (the only two escapes [`parse_simple_json_map`] needs to undo).
fn escape_json(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '\\' => out.push_str("\\\\"),
            '"' => out.push_str("\\\""),
            other => out.push(other),
        }
    }
    out
}
/// Hashes a file's raw bytes with 64-bit FNV-1a and returns the digest as a
/// fixed-width 16-char lowercase hex string. Panics if the file is unreadable.
fn content_hash(path: &Path) -> String {
    let bytes = std::fs::read(path)
        .unwrap_or_else(|e| panic!("Failed to read '{}': {}", path.display(), e));
    // FNV-1a: xor each byte into the accumulator, then multiply by the prime.
    let digest = bytes
        .iter()
        .fold(0xcbf29ce484222325u64, |acc, &b| {
            (acc ^ b as u64).wrapping_mul(0x100000001b3)
        });
    format!("{digest:016x}")
}
/// Recovers the dependency-glob list a custom handler reported on a previous
/// run (stored under `"<rel_path>:deps"`); an absent entry means no known
/// dependencies.
fn handler_dep_globs_cached(hashes: &HashMap<String, String>, dep_key: &str) -> Vec<String> {
    hashes
        .get(dep_key)
        .map(|json| parse_string_array(json))
        .unwrap_or_default()
}
/// Parses a flat JSON array of strings (`["a","b"]`); anything else yields
/// an empty list. Backslash escapes inside strings are unescaped; a string
/// left unterminated at the end of input is discarded.
fn parse_string_array(json: &str) -> Vec<String> {
    let trimmed = json.trim();
    if !(trimmed.starts_with('[') && trimmed.ends_with(']')) {
        return vec![];
    }
    let mut items = Vec::new();
    let mut chars = trimmed[1..trimmed.len() - 1].chars();
    loop {
        // Skip everything (commas, whitespace) until the next opening quote.
        match chars.next() {
            None => break,
            Some('"') => {
                let mut item = String::new();
                loop {
                    match chars.next() {
                        // Unterminated string: drop the partial item.
                        None => return items,
                        // Escape marker: keep the following char verbatim.
                        Some('\\') => {
                            if let Some(escaped) = chars.next() {
                                item.push(escaped);
                            }
                        }
                        Some('"') => break,
                        Some(c) => item.push(c),
                    }
                }
                items.push(item);
            }
            Some(_) => {}
        }
    }
    items
}
/// Returns `true` when any file matched by any dependency glob hashes
/// differently from the recorded value, or has no recorded value at all.
/// Globs that fail to expand are skipped (treated as unchanged).
fn check_dep_hashes_changed(dep_globs: &[String], hashes: &HashMap<String, String>) -> bool {
    for pattern in dep_globs {
        let Ok(paths) = expand_glob(pattern) else {
            continue;
        };
        for path in paths {
            if !path.exists() {
                continue;
            }
            let key = path.display().to_string();
            if hashes.get(&key) != Some(&content_hash(&path)) {
                return true;
            }
        }
    }
    false
}
/// Expands a '/'-separated glob pattern — literal segments, '*' wildcards
/// within a segment, and "**" for any directory depth — relative to the
/// current working directory, returning matched files.
fn expand_glob(pattern: &str) -> Result<Vec<PathBuf>, std::io::Error> {
    let segments: Vec<&str> = pattern.split('/').collect();
    let mut matched = Vec::new();
    expand_glob_recursive(Path::new("."), &segments, 0, &mut matched)?;
    Ok(matched)
}
/// Recursive worker for [`expand_glob`]: matches `parts[idx..]` against the
/// file system starting at `base`, appending every matched *file* (never a
/// directory) to `results`. Propagates only `read_dir` I/O errors.
fn expand_glob_recursive(
    base: &Path,
    parts: &[&str],
    idx: usize,
    results: &mut Vec<PathBuf>,
) -> Result<(), std::io::Error> {
    if idx >= parts.len() {
        // All segments consumed: record the path, but only if it is a file.
        if base.is_file() {
            results.push(base.to_path_buf());
        }
        return Ok(());
    }
    let part = parts[idx];
    if part == "**" {
        // "**" matches zero directories (try the remaining pattern here)…
        expand_glob_recursive(base, parts, idx + 1, results)?;
        // …or one-or-more (recurse into each subdirectory keeping "**").
        if base.is_dir() {
            for entry in std::fs::read_dir(base)? {
                let entry = entry?;
                let path = entry.path();
                if path.is_dir() {
                    expand_glob_recursive(&path, parts, idx, results)?;
                }
            }
        }
    } else if part.contains('*') {
        // Wildcard segment: test each directory entry name against it.
        if base.is_dir() {
            for entry in std::fs::read_dir(base)? {
                let entry = entry?;
                let name = entry.file_name();
                let name_str = name.to_string_lossy();
                if matches_wildcard(part, &name_str) {
                    expand_glob_recursive(&entry.path(), parts, idx + 1, results)?;
                }
            }
        }
    } else {
        // Literal segment: descend directly if that path exists.
        let next = base.join(part);
        if next.exists() {
            expand_glob_recursive(&next, parts, idx + 1, results)?;
        }
    }
    Ok(())
}
/// Matches a single path segment against a glob pattern where '*' matches
/// any (possibly empty) run of characters.
///
/// Generalized from the previous version, which only supported one leading
/// OR trailing '*': patterns such as `a*b` or `*o*` silently never matched.
/// Prefix-only, suffix-only, and literal patterns behave exactly as before.
fn matches_wildcard(pattern: &str, name: &str) -> bool {
    let p: Vec<char> = pattern.chars().collect();
    let n: Vec<char> = name.chars().collect();
    let (mut pi, mut ni) = (0usize, 0usize);
    // Position of the most recent '*' and the name index it was tried at,
    // for backtracking (classic greedy wildcard matching).
    let mut star: Option<usize> = None;
    let mut mark = 0usize;
    while ni < n.len() {
        if pi < p.len() && p[pi] == '*' {
            star = Some(pi);
            mark = ni;
            pi += 1;
        } else if pi < p.len() && p[pi] == n[ni] {
            pi += 1;
            ni += 1;
        } else if let Some(s) = star {
            // Mismatch after a '*': let the star absorb one more char.
            pi = s + 1;
            mark += 1;
            ni = mark;
        } else {
            return false;
        }
    }
    // Any remaining pattern chars must all be '*' (each can match empty).
    while pi < p.len() && p[pi] == '*' {
        pi += 1;
    }
    pi == p.len()
}
/// Returns every shader source under `dir` (recursively), sorted so the
/// build processes files in a deterministic order.
fn collect_shader_files(dir: &Path) -> Vec<PathBuf> {
    let mut found = Vec::new();
    collect_shader_files_recursive(dir, &mut found);
    found.sort();
    found
}
/// Depth-first scan for files with a recognized shader extension
/// ("slang", "hlsl", "glsl", "frag"); unreadable directories and entries
/// are silently skipped.
fn collect_shader_files_recursive(dir: &Path, files: &mut Vec<PathBuf>) {
    let Ok(entries) = std::fs::read_dir(dir) else {
        return;
    };
    for entry in entries.flatten() {
        let path = entry.path();
        if path.is_dir() {
            collect_shader_files_recursive(&path, files);
        } else if path.is_file() {
            let is_shader = path
                .extension()
                .and_then(|os| os.to_str())
                .map_or(false, |ext| matches!(ext, "slang" | "hlsl" | "glsl" | "frag"));
            if is_shader {
                files.push(path);
            }
        }
    }
}
/// Compiles a Slang/HLSL source to SPIR-V via the external `slangc` binary
/// and, when the `shader-build` feature is enabled, cross-compiles that
/// SPIR-V to GLSL in `_output_dir`. Returns `true` on success; on any
/// failure a `cargo:warning` is emitted and `false` is returned so the
/// caller can skip recording the hash (forcing a retry next run).
///
/// NOTE(review): the entry point ("main") and stage ("fragment") are
/// hard-coded, and output names use only the file stem — two shaders with
/// the same stem in different subdirectories would overwrite each other;
/// confirm this matches the project's shader layout.
fn compile_slang(source: &Path, rel_path: &Path, _output_dir: &Path, spirv_dir: &Path, slangc_path: Option<&Path>) -> bool {
    let stem = rel_path.file_stem().unwrap().to_string_lossy();
    let spv_path = spirv_dir.join(format!("{}.spv", stem));
    // Fall back to resolving "slangc" from PATH when no explicit path was set.
    let slangc_cmd = slangc_path.map(|p| p.as_os_str().to_owned()).unwrap_or_else(|| std::ffi::OsString::from("slangc"));
    let slangc_status = std::process::Command::new(&slangc_cmd)
        .arg(source)
        .arg("-target")
        .arg("spirv")
        .arg("-entry")
        .arg("main")
        .arg("-stage")
        .arg("fragment")
        .arg("-o")
        .arg(&spv_path)
        .status();
    match slangc_status {
        // Compiled cleanly; fall through to the SPIR-V -> GLSL step below.
        Ok(status) if status.success() => {}
        // slangc ran but reported a compile error.
        Ok(status) => {
            println!(
                "cargo:warning=slangc failed with exit code {} for '{}' — skipping",
                status,
                source.display()
            );
            return false;
        }
        // slangc could not be spawned at all (not installed / bad path).
        Err(e) => {
            println!(
                "cargo:warning=Could not run slangc for '{}': {}. Is slangc installed and on PATH? Skipping.",
                source.display(),
                e
            );
            return false;
        }
    }
    #[cfg(feature = "shader-build")]
    {
        // Cross-compile the SPIR-V to GLSL and stamp the generated header.
        let output_path = _output_dir.join(format!("{}.frag.glsl", stem));
        if !spirv_cross_library(&spv_path, &output_path) {
            return false;
        } else {
            prepend_header(&output_path, &format!("{}", rel_path.display()));
            return true;
        };
    }
    #[cfg(not(feature = "shader-build"))]
    {
        // Without the feature we cannot produce GLSL, so the shader is
        // reported as failed (hash not recorded) rather than half-built.
        println!("cargo:warning=The 'shader-build' feature is not enabled, but needed for SPIR-V to GLSL conversion. Please enable the 'shader-build' feature in your [dev-dependencies]!");
        return false;
    }
}
/// Cross-compiles a SPIR-V binary to GLSL via the `spirv-cross2` crate and
/// writes the result to `output_path`. Returns `false` (with a
/// `cargo:warning`) when compiler creation or compilation fails; panics on
/// file I/O errors or a malformed (non-word-aligned) SPIR-V file.
#[cfg(feature = "shader-build")]
fn spirv_cross_library(spv_path: &Path, output_path: &Path) -> bool {
    use spirv_cross2::compile::glsl::GlslVersion;
    use spirv_cross2::compile::CompilableTarget;
    use spirv_cross2::targets::Glsl;
    use spirv_cross2::{Compiler, Module};
    // SPIR-V is a stream of 32-bit words stored little-endian on disk.
    let spv_bytes = std::fs::read(spv_path)
        .unwrap_or_else(|e| panic!("Failed to read SPIR-V file '{}': {}", spv_path.display(), e));
    assert!(
        spv_bytes.len() % 4 == 0,
        "SPIR-V file size must be a multiple of 4 bytes"
    );
    let spv_words: Vec<u32> = spv_bytes
        .chunks_exact(4)
        .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
        .collect();
    let module = Module::from_words(&spv_words);
    let compiler = match Compiler::<Glsl>::new(module) {
        Ok(c) => c,
        Err(e) => {
            println!(
                "cargo:warning=spirv-cross2 failed to create GLSL compiler for '{}': {} — skipping",
                spv_path.display(),
                e
            );
            return false;
        }
    };
    let mut options = Glsl::options();
    // Emits GLSL 1.00 ES — presumably for GLES2/WebGL1-class targets;
    // NOTE(review): confirm against the runtime's required GLSL version.
    options.version = GlslVersion::Glsl100Es;
    let artifact = match compiler.compile(&options) {
        Ok(a) => a,
        Err(e) => {
            println!(
                "cargo:warning=spirv-cross2 failed to compile to GLSL: {} — skipping",
                e
            );
            return false;
        }
    };
    let glsl_source = artifact.to_string();
    std::fs::write(output_path, glsl_source)
        .unwrap_or_else(|e| panic!("Failed to write '{}': {}", output_path.display(), e));
    true
}
/// Pass-through handler for plain GLSL sources: prepends the generated-file
/// banner and writes the result into the build output directory as
/// `<stem>.frag.glsl`. Panics on I/O failure (build-script context).
fn copy_glsl(source: &Path, rel_path: &Path, output_dir: &Path) {
    let stem = rel_path.file_stem().unwrap().to_string_lossy();
    let destination = output_dir.join(format!("{stem}.frag.glsl"));
    let body = std::fs::read_to_string(source)
        .unwrap_or_else(|e| panic!("Failed to read '{}': {}", source.display(), e));
    let mut output = format!("{GENERATED_HEADER_PREFIX}\n//! Source: {}\n", rel_path.display());
    output.push_str(&body);
    std::fs::write(&destination, output)
        .unwrap_or_else(|e| panic!("Failed to write '{}': {}", destination.display(), e));
}
/// Rewrites `path` in place with the generated-file banner (plus the
/// original source path) followed by the file's current contents.
#[cfg(feature = "shader-build")]
fn prepend_header(path: &Path, source_path: &str) {
    let body = std::fs::read_to_string(path)
        .unwrap_or_else(|e| panic!("Failed to read '{}': {}", path.display(), e));
    let banner = format!("{GENERATED_HEADER_PREFIX}\n//! Source: {source_path}\n");
    std::fs::write(path, banner + &body)
        .unwrap_or_else(|e| panic!("Failed to write '{}': {}", path.display(), e));
}
#[cfg(test)]
mod tests {
    use super::*;

    // Hashing the same bytes twice yields the same digest; different bytes
    // yield a different digest.
    #[test]
    fn test_content_hash_deterministic() {
        let dir = std::env::temp_dir().join("ply_shader_build_test");
        std::fs::create_dir_all(&dir).unwrap();
        let file = dir.join("test.glsl");
        std::fs::write(&file, "void main() {}").unwrap();
        let first = content_hash(&file);
        assert_eq!(first, content_hash(&file));
        std::fs::write(&file, "void main() { gl_FragColor = vec4(1.0); }").unwrap();
        assert_ne!(content_hash(&file), first);
        std::fs::remove_dir_all(&dir).unwrap();
    }

    // Minimal JSON object parsing used for the hash cache.
    #[test]
    fn test_parse_simple_json_map() {
        let parsed = parse_simple_json_map(r#"{ "foo": "bar", "baz": "qux" }"#);
        assert_eq!(parsed.get("foo").unwrap(), "bar");
        assert_eq!(parsed.get("baz").unwrap(), "qux");
    }

    // Minimal JSON string-array parsing used for cached dependency globs.
    #[test]
    fn test_parse_string_array() {
        assert_eq!(parse_string_array(r#"["a","b","c"]"#), vec!["a", "b", "c"]);
    }

    // save_hashes -> load_hashes must round-trip every entry.
    #[test]
    fn test_hash_round_trip() {
        let dir = std::env::temp_dir().join("ply_shader_hash_test");
        std::fs::create_dir_all(&dir).unwrap();
        let hash_file = dir.join("hashes.json");
        let mut original = HashMap::default();
        original.insert("foo.slang".to_string(), "abcdef0123456789".to_string());
        original.insert("bar.glsl".to_string(), "9876543210fedcba".to_string());
        save_hashes(&hash_file, &original);
        let loaded = load_hashes(&hash_file);
        assert_eq!(loaded.get("foo.slang").unwrap(), "abcdef0123456789");
        assert_eq!(loaded.get("bar.glsl").unwrap(), "9876543210fedcba");
        std::fs::remove_dir_all(&dir).unwrap();
    }

    // Prefix and suffix wildcard patterns used by glob expansion.
    #[test]
    fn test_matches_wildcard() {
        for (pattern, name, expected) in [
            ("*.glsl", "test.glsl", true),
            ("*.glsl", "foo.glsl", true),
            ("*.glsl", "test.slang", false),
            ("test*", "test.glsl", true),
            ("test*", "test_file", true),
        ] {
            assert_eq!(matches_wildcard(pattern, name), expected, "{} vs {}", pattern, name);
        }
    }
}