1use std::collections::{HashMap, HashSet};
9
/// Pipeline stage a WGSL entry point is annotated for.
///
/// The variants correspond to the `@vertex`, `@fragment` and `@compute`
/// attributes recognised by the entry-point scanner. The enum is fieldless, so
/// it is `Copy` and `Hash` and can be passed by value or used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShaderStage {
    /// Entry point marked `@vertex`.
    Vertex,
    /// Entry point marked `@fragment`.
    Fragment,
    /// Entry point marked `@compute`.
    Compute,
}
19
20#[derive(Debug, Clone)]
22pub struct EntryPoint {
23 pub name: String,
24 pub stage: ShaderStage,
25}
26
27impl EntryPoint {
28 pub fn new(name: impl Into<String>, stage: ShaderStage) -> Self {
30 Self {
31 name: name.into(),
32 stage,
33 }
34 }
35}
36
37#[derive(Debug, Clone)]
41pub struct ShaderSource {
42 pub label: String,
44 pub wgsl_source: String,
46 pub entry_points: Vec<EntryPoint>,
48 pub version: u64,
51 pub last_modified: u64,
54}
55
56impl ShaderSource {
57 fn new(label: impl Into<String>, wgsl_source: impl Into<String>) -> Self {
59 let wgsl = wgsl_source.into();
60 let entry_points = parse_entry_points(&wgsl);
61 Self {
62 label: label.into(),
63 wgsl_source: wgsl,
64 entry_points,
65 version: 1,
66 last_modified: 0,
67 }
68 }
69
70 fn bump(&mut self, new_wgsl: impl Into<String>) {
72 self.wgsl_source = new_wgsl.into();
73 self.entry_points = parse_entry_points(&self.wgsl_source);
74 self.version += 1;
75 self.last_modified = current_unix_secs();
76 }
77}
78
79fn parse_entry_points(wgsl: &str) -> Vec<EntryPoint> {
84 let mut entries = Vec::new();
85 let mut lines = wgsl.lines().peekable();
86
87 while let Some(line) = lines.next() {
88 let trimmed = line.trim();
89
90 let stage_opt = if trimmed.contains("@vertex") {
92 Some(ShaderStage::Vertex)
93 } else if trimmed.contains("@fragment") {
94 Some(ShaderStage::Fragment)
95 } else if trimmed.contains("@compute") {
96 Some(ShaderStage::Compute)
97 } else {
98 None
99 };
100
101 if let Some(stage) = stage_opt {
102 let fn_name = extract_fn_name(trimmed)
104 .or_else(|| lines.peek().and_then(|next| extract_fn_name(next.trim())));
105
106 if let Some(name) = fn_name {
107 entries.push(EntryPoint::new(name, stage));
108 }
109 }
110 }
111
112 entries
113}
114
115fn extract_fn_name(line: &str) -> Option<String> {
117 let idx = line.find("fn ")?;
118 let after = line[idx + 3..].trim();
119 let end = after
121 .find(|c: char| c == '(' || c.is_whitespace())
122 .unwrap_or(after.len());
123 if end == 0 {
124 return None;
125 }
126 Some(after[..end].to_owned())
127}
128
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns 0 instead of panicking if the system clock reports a time before
/// the epoch.
fn current_unix_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
138
/// Report that a tracked shader's version moved between two polls.
#[derive(Debug, Clone)]
pub struct ShaderChangeEvent {
    /// Label of the shader that changed.
    pub label: String,
    /// Version recorded at the previous poll (or at registration).
    pub old_version: u64,
    /// Version observed at this poll.
    pub new_version: u64,
}
149
150pub struct ShaderWatcher {
159 pub watch_paths: Vec<String>,
161 pub poll_interval_ms: u64,
163 pub sources: HashMap<String, ShaderSource>,
165 snapshot: HashMap<String, u64>,
167}
168
169impl ShaderWatcher {
170 pub fn new(poll_interval_ms: u64) -> Self {
172 Self {
173 watch_paths: Vec::new(),
174 poll_interval_ms,
175 sources: HashMap::new(),
176 snapshot: HashMap::new(),
177 }
178 }
179
180 pub fn add_path(&mut self, path: impl Into<String>) {
186 self.watch_paths.push(path.into());
187 }
188
189 pub fn add_inline(&mut self, label: impl Into<String>, wgsl: impl Into<String>) {
192 let lbl: String = label.into();
193 let src = ShaderSource::new(lbl.clone(), wgsl);
194 self.snapshot.insert(lbl.clone(), src.version);
195 self.sources.insert(lbl, src);
196 }
197
198 pub fn poll_changes(&mut self) -> Vec<ShaderChangeEvent> {
205 let mut events = Vec::new();
206
207 for (label, src) in &self.sources {
208 let snap_version = self.snapshot.get(label).copied().unwrap_or(0);
209 if src.version != snap_version {
210 events.push(ShaderChangeEvent {
211 label: label.clone(),
212 old_version: snap_version,
213 new_version: src.version,
214 });
215 }
216 }
217
218 for (label, src) in &self.sources {
220 self.snapshot.insert(label.clone(), src.version);
221 }
222
223 events
224 }
225
226 pub fn update_source(&mut self, label: &str, new_wgsl: impl Into<String>) -> bool {
229 if let Some(src) = self.sources.get_mut(label) {
230 src.bump(new_wgsl);
231 true
232 } else {
233 false
234 }
235 }
236
237 pub fn get_source(&self, label: &str) -> Option<&ShaderSource> {
239 self.sources.get(label)
240 }
241
242 pub fn source_version(&self, label: &str) -> Option<u64> {
245 self.sources.get(label).map(|s| s.version)
246 }
247}
248
249pub struct HotReloadRegistry {
254 pub watcher: ShaderWatcher,
255 pub invalidated_pipelines: HashSet<String>,
257 pub reload_count: u64,
259 pipeline_deps: HashMap<String, HashSet<String>>,
261}
262
263impl Default for HotReloadRegistry {
264 fn default() -> Self {
265 Self::new()
266 }
267}
268
269impl HotReloadRegistry {
270 pub fn new() -> Self {
272 Self {
273 watcher: ShaderWatcher::new(500),
274 invalidated_pipelines: HashSet::new(),
275 reload_count: 0,
276 pipeline_deps: HashMap::new(),
277 }
278 }
279
280 pub fn register_pipeline(&mut self, pipeline_id: impl Into<String>, shader_label: &str) {
285 self.pipeline_deps
286 .entry(pipeline_id.into())
287 .or_default()
288 .insert(shader_label.to_owned());
289 }
290
291 pub fn process_changes(&mut self) -> Vec<String> {
295 let events = self.watcher.poll_changes();
296 if events.is_empty() {
297 return Vec::new();
298 }
299
300 let changed_labels: HashSet<&str> = events.iter().map(|e| e.label.as_str()).collect();
301
302 let mut newly_invalidated = Vec::new();
303
304 for (pipeline_id, deps) in &self.pipeline_deps {
305 if deps.iter().any(|l| changed_labels.contains(l.as_str()))
306 && !self.invalidated_pipelines.contains(pipeline_id)
307 {
308 newly_invalidated.push(pipeline_id.clone());
309 }
310 }
311
312 for id in &newly_invalidated {
313 self.invalidated_pipelines.insert(id.clone());
314 }
315
316 self.reload_count += events.len() as u64;
317 newly_invalidated
318 }
319
320 pub fn is_invalidated(&self, pipeline_id: &str) -> bool {
322 self.invalidated_pipelines.contains(pipeline_id)
323 }
324
325 pub fn clear_invalidated(&mut self, pipeline_id: &str) {
327 self.invalidated_pipelines.remove(pipeline_id);
328 }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a watcher that tracks a single inline shader.
    fn watcher_with(label: &str, wgsl: &str) -> ShaderWatcher {
        let mut watcher = ShaderWatcher::new(100);
        watcher.add_inline(label, wgsl);
        watcher
    }

    /// Builds a registry with one inline shader and one pipeline depending on
    /// it, with the registration already drained from the watcher.
    fn primed_registry(shader: &str, wgsl: &str, pipeline: &str) -> HotReloadRegistry {
        let mut registry = HotReloadRegistry::new();
        registry.watcher.add_inline(shader, wgsl);
        registry.register_pipeline(pipeline, shader);
        registry.watcher.poll_changes();
        registry
    }

    // ---- entry-point parsing ----

    #[test]
    fn test_parse_entry_points_compute() {
        let parsed = parse_entry_points("@compute @workgroup_size(64)\nfn main() {}");
        assert_eq!(parsed.len(), 1);
        let only = &parsed[0];
        assert_eq!(only.name, "main");
        assert_eq!(only.stage, ShaderStage::Compute);
    }

    #[test]
    fn test_parse_entry_points_vertex_fragment() {
        let parsed = parse_entry_points("@vertex fn vs_main() {}\n@fragment fn fs_main() {}");
        assert_eq!(parsed.len(), 2);
        let names: Vec<&str> = parsed.iter().map(|ep| ep.name.as_str()).collect();
        assert!(names.contains(&"vs_main"));
        assert!(names.contains(&"fs_main"));
    }

    #[test]
    fn test_parse_no_entry_points() {
        let parsed = parse_entry_points("struct Foo { x: f32 }");
        assert!(parsed.is_empty());
    }

    // ---- ShaderWatcher ----

    #[test]
    fn test_add_inline_and_get() {
        let watcher = watcher_with("my_shader", "@compute fn main() {}");
        let found = watcher
            .get_source("my_shader")
            .expect("source should exist");
        assert_eq!(found.label, "my_shader");
        assert_eq!(found.version, 1);
    }

    #[test]
    fn test_get_unknown_label_returns_none() {
        assert!(ShaderWatcher::new(100).get_source("unknown").is_none());
    }

    #[test]
    fn test_source_version_initial() {
        let watcher = watcher_with("s", "@compute fn main() {}");
        assert_eq!(watcher.source_version("s"), Some(1));
    }

    #[test]
    fn test_source_version_unknown() {
        assert_eq!(ShaderWatcher::new(100).source_version("nope"), None);
    }

    #[test]
    fn test_update_source_bumps_version() {
        let mut watcher = watcher_with("s", "@compute fn main() {}");
        assert!(watcher.update_source("s", "@compute fn main_v2() {}"));
        assert_eq!(watcher.source_version("s"), Some(2));
    }

    #[test]
    fn test_update_source_unknown_returns_false() {
        let mut watcher = ShaderWatcher::new(100);
        assert!(!watcher.update_source("ghost", "@compute fn x() {}"));
    }

    #[test]
    fn test_update_source_multiple_bumps() {
        let mut watcher = watcher_with("s", "fn main() {}");
        (2..=5_u64).for_each(|expected| {
            watcher.update_source("s", format!("fn main_{expected}() {{}}"));
            assert_eq!(watcher.source_version("s"), Some(expected));
        });
    }

    #[test]
    fn test_poll_changes_after_update() {
        let mut watcher = watcher_with("s", "@compute fn main() {}");
        // Registration on its own is not reported as a change.
        assert!(watcher.poll_changes().is_empty(), "first poll should be empty");

        watcher.update_source("s", "@compute fn main_v2() {}");
        let events = watcher.poll_changes();
        assert_eq!(events.len(), 1);
        let event = &events[0];
        assert_eq!(event.label, "s");
        assert_eq!(event.old_version, 1);
        assert_eq!(event.new_version, 2);
    }

    #[test]
    fn test_poll_changes_clears_on_second_poll() {
        let mut watcher = watcher_with("s", "fn main() {}");
        watcher.update_source("s", "fn main_v2() {}");
        let _ = watcher.poll_changes();
        assert!(watcher.poll_changes().is_empty());
    }

    #[test]
    fn test_add_path_stores_path() {
        let path = std::env::temp_dir().join("oxigdal_test_shader_bx9f.wgsl");
        let path_str = path.to_string_lossy().into_owned();
        let mut watcher = ShaderWatcher::new(100);
        watcher.add_path(path_str.clone());
        assert_eq!(watcher.watch_paths, vec![path_str]);
    }

    #[test]
    fn test_multiple_inline_sources() {
        let mut watcher = watcher_with("a", "fn a() {}");
        watcher.add_inline("b", "fn b() {}");
        assert_eq!(watcher.sources.len(), 2);
    }

    // ---- HotReloadRegistry ----

    #[test]
    fn test_registry_new_not_invalidated() {
        assert!(!HotReloadRegistry::new().is_invalidated("pipeline_a"));
    }

    #[test]
    fn test_registry_process_changes_invalidates_pipeline() {
        let mut registry = primed_registry("my_shader", "@compute fn main() {}", "pipeline_a");
        registry
            .watcher
            .update_source("my_shader", "@compute fn main_v2() {}");

        let invalidated = registry.process_changes();
        assert!(invalidated.iter().any(|id| id == "pipeline_a"));
        assert!(registry.is_invalidated("pipeline_a"));
    }

    #[test]
    fn test_registry_process_changes_no_change() {
        let mut registry = primed_registry("s", "@compute fn main() {}", "p");
        assert!(registry.process_changes().is_empty());
    }

    #[test]
    fn test_registry_clear_invalidated() {
        let mut registry = primed_registry("s", "@compute fn main() {}", "p");
        registry.watcher.update_source("s", "@compute fn new_main() {}");
        registry.process_changes();
        assert!(registry.is_invalidated("p"));
        registry.clear_invalidated("p");
        assert!(!registry.is_invalidated("p"));
    }

    #[test]
    fn test_registry_reload_count_increments() {
        let mut registry = primed_registry("s", "fn main() {}", "p");
        let rounds = [(1_u64, "fn main_v2() {}"), (2_u64, "fn main_v3() {}")];
        for &(expected_count, wgsl) in rounds.iter() {
            registry.watcher.update_source("s", wgsl);
            registry.process_changes();
            assert_eq!(registry.reload_count, expected_count);
        }
    }

    #[test]
    fn test_registry_unrelated_shader_does_not_invalidate() {
        let mut registry = HotReloadRegistry::new();
        registry.watcher.add_inline("shader_a", "fn a() {}");
        registry.watcher.add_inline("shader_b", "fn b() {}");
        registry.register_pipeline("pipeline_a", "shader_a");
        registry.watcher.poll_changes();

        registry.watcher.update_source("shader_b", "fn b_v2() {}");
        let invalidated = registry.process_changes();
        assert!(!invalidated.iter().any(|id| id == "pipeline_a"));
        assert!(!registry.is_invalidated("pipeline_a"));
    }

    // ---- EntryPoint / ShaderSource ----

    #[test]
    fn test_entry_point_new() {
        let ep = EntryPoint::new("vs_main", ShaderStage::Vertex);
        assert_eq!(ep.name, "vs_main");
        assert_eq!(ep.stage, ShaderStage::Vertex);
    }

    #[test]
    fn test_shader_source_entry_points_populated() {
        let watcher = watcher_with("s", "@compute\nfn my_compute() {}");
        let src = watcher.get_source("s").expect("source should exist");
        assert_eq!(src.entry_points.len(), 1);
        assert_eq!(src.entry_points[0].name, "my_compute");
    }

    #[test]
    fn test_update_source_refreshes_entry_points() {
        let mut watcher = watcher_with("s", "@compute fn compute_v1() {}");
        watcher.update_source("s", "@vertex fn vs_main() {}");
        let src = watcher.get_source("s").expect("source should exist");
        let first = &src.entry_points[0];
        assert_eq!(first.stage, ShaderStage::Vertex);
        assert_eq!(first.name, "vs_main");
    }
}