1use rustc_hash::FxHashMap as HashMap;
2use std::path::{Path, PathBuf};
3
/// First line written at the top of every generated shader file, marking it as
/// build output that must not be edited by hand.
const GENERATED_HEADER_PREFIX: &str = "//! DO NOT EDIT: This shader file was generated by Ply!";

/// Custom build step for one shader file extension.
///
/// Invoked as `handler(source_file, output_dir)`; the returned strings are glob
/// patterns naming files the shader depends on, which are recorded alongside the
/// shader's hash and consulted for change detection on later runs.
pub type FileTypeHandler = Box<dyn Fn(&Path, &Path) -> Vec<String> + 'static>;
13
14pub struct ShaderBuild {
19 source_dir: PathBuf,
20 output_dir: PathBuf,
21 spirv_dir: PathBuf,
22 hash_file: PathBuf,
23 slangc_path: Option<PathBuf>,
24 custom_handlers: HashMap<String, FileTypeHandler>,
25}
26
27impl ShaderBuild {
28 pub fn new() -> Self {
36 Self {
37 source_dir: PathBuf::from("shaders/"),
38 output_dir: PathBuf::from("assets/build/shaders/"),
39 spirv_dir: PathBuf::from("build/shaders/spirv/"),
40 hash_file: PathBuf::from("build/shaders/hashes.json"),
41 slangc_path: None,
42 custom_handlers: HashMap::default(),
43 }
44 }
45
46 pub fn source_dir(mut self, dir: &str) -> Self {
48 self.source_dir = PathBuf::from(dir);
49 self
50 }
51
52 pub fn output_dir(mut self, dir: &str) -> Self {
54 self.output_dir = PathBuf::from(dir);
55 self
56 }
57
58 pub fn slangc_path(mut self, path: &str) -> Self {
72 self.slangc_path = Some(PathBuf::from(path));
73 self
74 }
75
76 pub fn override_file_type_handler(
95 mut self,
96 extension: &str,
97 handler: impl Fn(&Path, &Path) -> Vec<String> + 'static,
98 ) -> Self {
99 let ext = if extension.starts_with('.') {
100 extension.to_string()
101 } else {
102 format!(".{}", extension)
103 };
104 self.custom_handlers.insert(ext, Box::new(handler));
105 self
106 }
107
108 pub fn build(self) {
117 println!("cargo:rerun-if-changed={}", self.source_dir.display());
119
120 std::fs::create_dir_all(&self.output_dir)
122 .unwrap_or_else(|e| panic!("Failed to create output dir '{}': {}", self.output_dir.display(), e));
123 std::fs::create_dir_all(&self.spirv_dir)
124 .unwrap_or_else(|e| panic!("Failed to create SPIR-V dir '{}': {}", self.spirv_dir.display(), e));
125 if let Some(parent) = self.hash_file.parent() {
126 std::fs::create_dir_all(parent)
127 .unwrap_or_else(|e| panic!("Failed to create hash dir '{}': {}", parent.display(), e));
128 }
129
130 let hashes = load_hashes(&self.hash_file);
132 let mut new_hashes = HashMap::default();
133
134 if !self.source_dir.exists() {
136 println!(
137 "cargo:warning=Shader source directory '{}' does not exist, skipping shader build",
138 self.source_dir.display()
139 );
140 return;
141 }
142
143 let shader_files = collect_shader_files(&self.source_dir);
144 if shader_files.is_empty() {
145 println!(
146 "cargo:warning=No shader files found in '{}'",
147 self.source_dir.display()
148 );
149 return;
150 }
151
152 for file_path in &shader_files {
153 let ext = file_path
154 .extension()
155 .and_then(|e| e.to_str())
156 .map(|e| format!(".{}", e))
157 .unwrap_or_default();
158
159 let rel_path = file_path
160 .strip_prefix(&self.source_dir)
161 .unwrap_or(file_path);
162
163 let source_hash = content_hash(file_path);
165
166 let dep_key = format!("{}:deps", rel_path.display());
168 let dep_hashes_changed = if let Some(_handler) = self.custom_handlers.get(&ext) {
169 let dep_globs = handler_dep_globs_cached(&hashes, &dep_key);
170 check_dep_hashes_changed(&dep_globs, &hashes)
171 } else {
172 false
173 };
174
175 let hash_key = rel_path.display().to_string();
177 let needs_rebuild = match hashes.get(&hash_key) {
178 Some(old_hash) => *old_hash != source_hash || dep_hashes_changed,
179 None => true,
180 };
181
182 if !needs_rebuild {
183 new_hashes.insert(hash_key, source_hash);
185 if let Some(deps) = hashes.get(&dep_key) {
186 new_hashes.insert(dep_key.clone(), deps.clone());
187 }
188 let dep_globs = handler_dep_globs_cached(&hashes, &dep_key);
190 for pattern in &dep_globs {
191 if let Ok(paths) = expand_glob(pattern) {
192 for path in paths {
193 let key = path.display().to_string();
194 if let Some(h) = hashes.get(&key) {
195 new_hashes.insert(key, h.clone());
196 }
197 }
198 }
199 }
200 continue;
201 }
202
203 println!("cargo:warning=Compiling shader: {}", rel_path.display());
204
205 let dep_globs = if let Some(handler) = self.custom_handlers.get(&ext) {
207 let globs = handler(file_path, &self.output_dir);
209 globs
210 } else {
211 match ext.as_str() {
212 ".slang" | ".hlsl" => {
213 if !compile_slang(file_path, rel_path, &self.output_dir, &self.spirv_dir, self.slangc_path.as_deref()) {
214 continue;
215 }
216 vec![]
217 }
218 ".glsl" | ".frag" => {
219 copy_glsl(file_path, rel_path, &self.output_dir);
220 vec![]
221 }
222 other => {
223 println!(
224 "cargo:warning=Unknown shader extension '{}' for file '{}', skipping",
225 other,
226 rel_path.display()
227 );
228 continue;
229 }
230 }
231 };
232
233 new_hashes.insert(hash_key, source_hash);
234 if !dep_globs.is_empty() {
235 let dep_json = format!(
237 "[{}]",
238 dep_globs
239 .iter()
240 .map(|g| format!("\"{}\"", g.replace('\\', "\\\\")))
241 .collect::<Vec<_>>()
242 .join(",")
243 );
244 new_hashes.insert(dep_key, dep_json);
245
246 for pattern in &dep_globs {
248 if let Ok(paths) = expand_glob(pattern) {
249 for path in paths {
250 if path.exists() {
251 let dep_hash = content_hash(&path);
252 new_hashes.insert(path.display().to_string(), dep_hash);
253 }
254 }
255 }
256 }
257 }
258 }
259
260 save_hashes(&self.hash_file, &new_hashes);
262 }
263}
264
265impl Default for ShaderBuild {
266 fn default() -> Self {
267 Self::new()
268 }
269}
270
271fn load_hashes(path: &Path) -> HashMap<String, String> {
272 if !path.exists() {
273 return HashMap::default();
274 }
275 let content = std::fs::read_to_string(path).unwrap_or_default();
276 parse_simple_json_map(&content)
277}
278
279fn save_hashes(path: &Path, hashes: &HashMap<String, String>) {
280 let mut entries: Vec<_> = hashes.iter().collect();
281 entries.sort_by_key(|(k, _)| (*k).clone());
282
283 let mut json = String::from("{\n");
284 for (i, (key, value)) in entries.iter().enumerate() {
285 json.push_str(&format!(
286 " \"{}\": \"{}\"",
287 escape_json(key),
288 escape_json(value)
289 ));
290 if i < entries.len() - 1 {
291 json.push(',');
292 }
293 json.push('\n');
294 }
295 json.push('}');
296
297 std::fs::write(path, json)
298 .unwrap_or_else(|e| panic!("Failed to write hashes to '{}': {}", path.display(), e));
299}
300
/// Parses a flat JSON object whose keys and values are all strings (the format
/// written by `save_hashes`) into a map.
///
/// This is a deliberately small, lenient parser. Characters outside strings
/// other than `:` and `,` are ignored. Input not wrapped in `{ … }` yields an
/// empty map, so a missing or corrupt hash file just forces a rebuild.
///
/// Escape sequences inside strings are decoded (`\"`, `\\`, `\/`, `\b`, `\f`,
/// `\n`, `\r`, `\t`, `\uXXXX`). Decoding is required for round-trips:
/// `save_hashes` escapes `\` and `"`, so a cached dependency array `["a"]` is
/// stored as `"[\"a\"]"` and must come back without the backslashes.
fn parse_simple_json_map(json: &str) -> HashMap<String, String> {
    /// Decodes the escape sequence that follows a backslash; `None` at end of input.
    fn decode_escape(chars: &mut std::str::Chars<'_>) -> Option<char> {
        let decoded = match chars.next()? {
            'b' => '\u{8}',
            'f' => '\u{c}',
            'n' => '\n',
            'r' => '\r',
            't' => '\t',
            'u' => {
                let hex: String = chars.by_ref().take(4).collect();
                u32::from_str_radix(&hex, 16)
                    .ok()
                    .and_then(char::from_u32)
                    .unwrap_or(char::REPLACEMENT_CHARACTER)
            }
            // `\"`, `\\`, `\/`, and (leniently) any other escaped character.
            other => other,
        };
        Some(decoded)
    }

    let mut map = HashMap::default();
    let trimmed = json.trim();
    let Some(inner) = trimmed
        .strip_prefix('{')
        .and_then(|rest| rest.strip_suffix('}'))
    else {
        return map;
    };

    let mut key = String::new();
    let mut current = String::new();
    let mut in_string = false;
    // Set by `:` and cleared by `,`: tells whether the next string is a value or a key.
    let mut after_colon = false;
    let mut chars = inner.chars();

    while let Some(ch) = chars.next() {
        if in_string {
            match ch {
                '\\' => {
                    if let Some(decoded) = decode_escape(&mut chars) {
                        current.push(decoded);
                    }
                }
                '"' => {
                    in_string = false;
                    if after_colon {
                        map.insert(std::mem::take(&mut key), std::mem::take(&mut current));
                        after_colon = false;
                    } else {
                        key = std::mem::take(&mut current);
                    }
                }
                _ => current.push(ch),
            }
            continue;
        }
        match ch {
            '"' => in_string = true,
            ':' => after_colon = true,
            ',' => after_colon = false,
            _ => {} // whitespace and anything else outside strings is ignored
        }
    }
    map
}
379
/// Escapes `s` for embedding in a JSON string literal by prefixing every `\`
/// and `"` with a backslash. No other characters are escaped.
fn escape_json(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for c in s.chars() {
        if matches!(c, '\\' | '"') {
            escaped.push('\\');
        }
        escaped.push(c);
    }
    escaped
}
383
/// Hashes the bytes of the file at `path` with 64-bit FNV-1a and returns the
/// digest as 16 lowercase hex digits.
///
/// Not cryptographic; it is only used to detect content changes.
///
/// # Panics
///
/// Panics if the file cannot be read.
fn content_hash(path: &Path) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    let data = std::fs::read(path)
        .unwrap_or_else(|e| panic!("Failed to read '{}': {}", path.display(), e));
    let digest = data
        .iter()
        .fold(FNV_OFFSET_BASIS, |acc, &byte| {
            (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
        });
    format!("{:016x}", digest)
}
395
396fn handler_dep_globs_cached(hashes: &HashMap<String, String>, dep_key: &str) -> Vec<String> {
397 match hashes.get(dep_key) {
398 Some(json_str) => parse_string_array(json_str),
399 None => vec![],
400 }
401}
402
/// Parses a JSON array of strings such as `["a","b"]`.
///
/// Anything between the strings (commas, whitespace) is ignored. Inside a
/// string, a backslash makes the following character literal. Input not wrapped
/// in `[ … ]` yields an empty list.
fn parse_string_array(json: &str) -> Vec<String> {
    let body = json.trim();
    let inner = match body.strip_prefix('[').and_then(|b| b.strip_suffix(']')) {
        Some(inner) => inner,
        None => return Vec::new(),
    };

    let mut items = Vec::new();
    let mut buf = String::new();
    let mut inside = false;
    let mut chars = inner.chars();

    while let Some(c) = chars.next() {
        match (inside, c) {
            (true, '\\') => {
                if let Some(literal) = chars.next() {
                    buf.push(literal);
                }
            }
            (_, '"') => {
                if inside {
                    items.push(std::mem::take(&mut buf));
                }
                inside = !inside;
            }
            (true, other) => buf.push(other),
            (false, _) => {}
        }
    }
    items
}
438
439fn check_dep_hashes_changed(dep_globs: &[String], hashes: &HashMap<String, String>) -> bool {
440 for pattern in dep_globs {
441 if let Ok(paths) = expand_glob(pattern) {
442 for path in paths {
443 if !path.exists() {
444 continue;
445 }
446 let current_hash = content_hash(&path);
447 let hash_key = path.display().to_string();
448 match hashes.get(&hash_key) {
449 Some(old_hash) if *old_hash == current_hash => {}
450 _ => return true, }
452 }
453 }
454 }
455 false
456}
457
/// Expands a `/`-separated glob pattern into the existing files it matches.
///
/// Relative patterns are resolved against the current working directory (the
/// package root when run from a Cargo build script), and the results are
/// prefixed with `./`. Patterns starting with `/` are resolved from the
/// filesystem root. Three kinds of segment are supported:
///
/// - `**` matches zero or more directories;
/// - a segment containing `*` is matched per [`matches_wildcard`];
/// - any other segment is a literal name.
///
/// Only files are returned, never directories.
///
/// # Errors
///
/// Returns an I/O error if a directory that must be listed cannot be read.
fn expand_glob(pattern: &str) -> Result<Vec<PathBuf>, std::io::Error> {
    // A leading `/` splits into an empty first segment. Joined onto `.`, that
    // turned absolute patterns into relative ones that never matched.
    let (root, relative) = match pattern.strip_prefix('/') {
        Some(rest) => (Path::new("/"), rest),
        None => (Path::new("."), pattern),
    };
    let parts: Vec<&str> = relative.split('/').collect();
    let mut results = Vec::new();
    expand_glob_recursive(root, &parts, 0, &mut results)?;
    Ok(results)
}

/// Matches `parts[idx..]` below `base`, appending every matching file to `results`.
///
/// `**` is first tried at `base` itself (zero directories), then re-applied
/// inside every subdirectory. Directory symlinks are followed (`is_dir`
/// traverses links), and symlink cycles under `**` are not detected.
fn expand_glob_recursive(
    base: &Path,
    parts: &[&str],
    idx: usize,
    results: &mut Vec<PathBuf>,
) -> Result<(), std::io::Error> {
    let Some(&part) = parts.get(idx) else {
        // All segments consumed: `base` is a match if it names a file.
        if base.is_file() {
            results.push(base.to_path_buf());
        }
        return Ok(());
    };

    if part == "**" {
        expand_glob_recursive(base, parts, idx + 1, results)?;
        if base.is_dir() {
            for entry in std::fs::read_dir(base)? {
                let path = entry?.path();
                if path.is_dir() {
                    expand_glob_recursive(&path, parts, idx, results)?;
                }
            }
        }
    } else if part.contains('*') {
        if base.is_dir() {
            for entry in std::fs::read_dir(base)? {
                let entry = entry?;
                if matches_wildcard(part, &entry.file_name().to_string_lossy()) {
                    expand_glob_recursive(&entry.path(), parts, idx + 1, results)?;
                }
            }
        }
    } else {
        let next = base.join(part);
        if next.exists() {
            expand_glob_recursive(&next, parts, idx + 1, results)?;
        }
    }

    Ok(())
}

/// Returns whether the single path segment `name` matches `pattern`.
///
/// Each `*` in `pattern` matches any (possibly empty) run of characters, and
/// every other character must match literally. A pattern without `*` must equal
/// `name` exactly.
fn matches_wildcard(pattern: &str, name: &str) -> bool {
    let segments: Vec<&str> = pattern.split('*').collect();
    // Fewer than two segments means the pattern contains no `*`.
    let [first, middle @ .., last] = segments.as_slice() else {
        return pattern == name;
    };

    let Some(mut rest) = name.strip_prefix(first) else {
        return false;
    };
    // Matching each middle literal at its leftmost position leaves the longest
    // possible remainder, so it never rules out a suffix match for `last`.
    for segment in middle {
        match rest.find(segment) {
            Some(pos) => rest = &rest[pos + segment.len()..],
            None => return false,
        }
    }
    rest.ends_with(last)
}
526
/// Returns every shader source (`.slang`, `.hlsl`, `.glsl`, `.frag`) under `dir`,
/// found recursively and sorted so files are processed in a stable order.
/// Unreadable directories and entries are skipped silently.
fn collect_shader_files(dir: &Path) -> Vec<PathBuf> {
    let mut found = Vec::new();
    collect_shader_files_recursive(dir, &mut found);
    found.sort();
    found
}

/// Depth-first walk that appends the shader files under `dir` to `files`.
fn collect_shader_files_recursive(dir: &Path, files: &mut Vec<PathBuf>) {
    let Ok(entries) = std::fs::read_dir(dir) else {
        return;
    };
    for path in entries.flatten().map(|entry| entry.path()) {
        if path.is_dir() {
            collect_shader_files_recursive(&path, files);
            continue;
        }
        let is_shader = path.is_file()
            && matches!(
                path.extension().and_then(|ext| ext.to_str()),
                Some("slang" | "hlsl" | "glsl" | "frag")
            );
        if is_shader {
            files.push(path);
        }
    }
}
556
/// Compiles a Slang/HLSL source to SPIR-V with `slangc` (entry point `main`,
/// fragment stage), writing `<spirv_dir>/<stem>.spv`.
///
/// With the `shader-build` feature enabled, it then converts the SPIR-V to
/// `<output_dir>/<stem>.frag.glsl` via `spirv_cross_library` and prepends the
/// generated-file header. `slangc_path` overrides the `slangc` found on `PATH`.
///
/// Returns `true` once the GLSL output has been written. Every failure returns
/// `false` after a `cargo:warning`, so the caller can skip the file:
///
/// - `slangc` cannot be started;
/// - `slangc` reports an error (its diagnostics are forwarded as warnings);
/// - the SPIR-V to GLSL conversion fails;
/// - the feature is disabled.
///
/// Outputs are named by file stem only, so sources with the same stem in
/// different subdirectories write to the same output files.
fn compile_slang(
    source: &Path,
    rel_path: &Path,
    _output_dir: &Path, // only read when the `shader-build` feature is enabled
    spirv_dir: &Path,
    slangc_path: Option<&Path>,
) -> bool {
    let stem = rel_path
        .file_stem()
        .expect("shader paths come from a directory listing and always have a file name")
        .to_string_lossy();
    let spv_path = spirv_dir.join(format!("{}.spv", stem));

    let slangc_cmd = slangc_path.map_or_else(
        || std::ffi::OsString::from("slangc"),
        |p| p.as_os_str().to_owned(),
    );

    // Capture slangc's output instead of inheriting it. A build script's stdout is
    // parsed by Cargo for `cargo:` directives, and its stderr is hidden unless the
    // build fails, which this function never makes happen.
    let output = match std::process::Command::new(&slangc_cmd)
        .arg(source)
        .args(["-target", "spirv", "-entry", "main", "-stage", "fragment", "-o"])
        .arg(&spv_path)
        .output()
    {
        Ok(output) => output,
        Err(e) => {
            println!(
                "cargo:warning=Could not run slangc for '{}': {}. Is slangc installed and on PATH? Skipping.",
                source.display(),
                e
            );
            return false;
        }
    };

    if !output.status.success() {
        // `ExitStatus` already displays as e.g. "exit status: 1".
        println!(
            "cargo:warning=slangc failed ({}) for '{}' — skipping",
            output.status,
            source.display()
        );
        let stderr = String::from_utf8_lossy(&output.stderr);
        let stdout = String::from_utf8_lossy(&output.stdout);
        for line in stderr
            .lines()
            .chain(stdout.lines())
            .filter(|line| !line.trim().is_empty())
        {
            println!("cargo:warning=  {}", line);
        }
        return false;
    }

    #[cfg(feature = "shader-build")]
    {
        let output_path = _output_dir.join(format!("{}.frag.glsl", stem));
        if !spirv_cross_library(&spv_path, &output_path) {
            return false;
        }
        prepend_header(&output_path, &rel_path.display().to_string());
        return true;
    }
    #[cfg(not(feature = "shader-build"))]
    {
        println!("cargo:warning=The 'shader-build' feature is not enabled, but needed for SPIR-V to GLSL conversion. Please enable the 'shader-build' feature in your [dev-dependencies]!");
        return false;
    }
}
615
/// Converts the SPIR-V module at `spv_path` to GLSL (version
/// `GlslVersion::Glsl100Es`) with `spirv-cross2` and writes it to `output_path`.
///
/// Returns `false` after a `cargo:warning` when:
///
/// - the SPIR-V file cannot be read;
/// - its size is not a whole number of 32-bit words;
/// - spirv-cross2 rejects it.
///
/// Like the other failures in the build, such a file is skipped rather than
/// aborting the whole build.
///
/// # Panics
///
/// Panics if the GLSL output cannot be written.
#[cfg(feature = "shader-build")]
fn spirv_cross_library(spv_path: &Path, output_path: &Path) -> bool {
    use spirv_cross2::compile::glsl::GlslVersion;
    use spirv_cross2::compile::CompilableTarget;
    use spirv_cross2::targets::Glsl;
    use spirv_cross2::{Compiler, Module};

    let spv_bytes = match std::fs::read(spv_path) {
        Ok(bytes) => bytes,
        Err(e) => {
            println!(
                "cargo:warning=Failed to read SPIR-V file '{}': {} — skipping",
                spv_path.display(),
                e
            );
            return false;
        }
    };

    // SPIR-V is a stream of 32-bit words. A malformed file should skip this
    // shader with a warning rather than panic the build script (the old `assert!`).
    if spv_bytes.len() % 4 != 0 {
        println!(
            "cargo:warning=SPIR-V file '{}' is {} bytes, not a multiple of 4 — skipping",
            spv_path.display(),
            spv_bytes.len()
        );
        return false;
    }
    let spv_words: Vec<u32> = spv_bytes
        .chunks_exact(4)
        .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
        .collect();

    let module = Module::from_words(&spv_words);

    let compiler = match Compiler::<Glsl>::new(module) {
        Ok(c) => c,
        Err(e) => {
            println!(
                "cargo:warning=spirv-cross2 failed to create GLSL compiler for '{}': {} — skipping",
                spv_path.display(),
                e
            );
            return false;
        }
    };

    let mut options = Glsl::options();
    options.version = GlslVersion::Glsl100Es;

    let artifact = match compiler.compile(&options) {
        Ok(a) => a,
        Err(e) => {
            println!(
                "cargo:warning=spirv-cross2 failed to compile '{}' to GLSL: {} — skipping",
                spv_path.display(),
                e
            );
            return false;
        }
    };

    let glsl_source = artifact.to_string();

    std::fs::write(output_path, glsl_source)
        .unwrap_or_else(|e| panic!("Failed to write '{}': {}", output_path.display(), e));
    true
}
671
672fn copy_glsl(source: &Path, rel_path: &Path, output_dir: &Path) {
674 let stem = rel_path.file_stem().unwrap().to_string_lossy();
675 let output_path = output_dir.join(format!("{}.frag.glsl", stem));
676
677 let content = std::fs::read_to_string(source)
678 .unwrap_or_else(|e| panic!("Failed to read '{}': {}", source.display(), e));
679
680 let header = format!(
681 "{}\n//! Source: {}\n",
682 GENERATED_HEADER_PREFIX,
683 rel_path.display()
684 );
685
686 std::fs::write(&output_path, format!("{}{}", header, content))
687 .unwrap_or_else(|e| panic!("Failed to write '{}': {}", output_path.display(), e));
688}
689
/// Rewrites the file at `path` so it starts with the generated-file header: the
/// `GENERATED_HEADER_PREFIX` line, then a `//! Source:` line naming `source_path`.
///
/// # Panics
///
/// Panics if the file cannot be read or written.
#[cfg(feature = "shader-build")]
fn prepend_header(path: &Path, source_path: &str) {
    let existing = std::fs::read_to_string(path)
        .unwrap_or_else(|e| panic!("Failed to read '{}': {}", path.display(), e));

    let mut rewritten = format!(
        "{}\n//! Source: {}\n",
        GENERATED_HEADER_PREFIX, source_path
    );
    rewritten.push_str(&existing);

    std::fs::write(path, rewritten)
        .unwrap_or_else(|e| panic!("Failed to write '{}': {}", path.display(), e));
}
704
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_content_hash_deterministic() {
        let scratch = std::env::temp_dir().join("ply_shader_build_test");
        std::fs::create_dir_all(&scratch).unwrap();
        let shader = scratch.join("test.glsl");

        std::fs::write(&shader, "void main() {}").unwrap();
        let before = content_hash(&shader);
        // Hashing identical bytes twice must agree.
        assert_eq!(before, content_hash(&shader));

        // Changing the content must change the hash.
        std::fs::write(&shader, "void main() { gl_FragColor = vec4(1.0); }").unwrap();
        assert_ne!(before, content_hash(&shader));

        std::fs::remove_dir_all(&scratch).unwrap();
    }

    #[test]
    fn test_parse_simple_json_map() {
        let parsed = parse_simple_json_map(r#"{ "foo": "bar", "baz": "qux" }"#);
        assert_eq!(parsed.get("foo").map(String::as_str), Some("bar"));
        assert_eq!(parsed.get("baz").map(String::as_str), Some("qux"));
    }

    #[test]
    fn test_parse_string_array() {
        assert_eq!(parse_string_array(r#"["a","b","c"]"#), vec!["a", "b", "c"]);
    }

    #[test]
    fn test_hash_round_trip() {
        let scratch = std::env::temp_dir().join("ply_shader_hash_test");
        std::fs::create_dir_all(&scratch).unwrap();
        let hash_file = scratch.join("hashes.json");

        let entries = [
            ("foo.slang", "abcdef0123456789"),
            ("bar.glsl", "9876543210fedcba"),
        ];
        let hashes: HashMap<String, String> = entries
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect();

        save_hashes(&hash_file, &hashes);
        let loaded = load_hashes(&hash_file);
        for &(key, value) in &entries {
            assert_eq!(loaded.get(key).map(String::as_str), Some(value));
        }

        std::fs::remove_dir_all(&scratch).unwrap();
    }

    #[test]
    fn test_matches_wildcard() {
        let cases = [
            ("*.glsl", "test.glsl", true),
            ("*.glsl", "foo.glsl", true),
            ("*.glsl", "test.slang", false),
            ("test*", "test.glsl", true),
            ("test*", "test_file", true),
        ];
        for (pattern, name, expected) in cases {
            assert_eq!(
                matches_wildcard(pattern, name),
                expected,
                "pattern {:?} vs name {:?}",
                pattern,
                name
            );
        }
    }
}
760
761 #[test]
762 fn test_matches_wildcard() {
763 assert!(matches_wildcard("*.glsl", "test.glsl"));
764 assert!(matches_wildcard("*.glsl", "foo.glsl"));
765 assert!(!matches_wildcard("*.glsl", "test.slang"));
766 assert!(matches_wildcard("test*", "test.glsl"));
767 assert!(matches_wildcard("test*", "test_file"));
768 }
769}