1use std::collections::{HashMap, HashSet};
9
/// Pipeline stage a WGSL entry point is declared for.
///
/// Detected from the `@vertex`, `@fragment` and `@compute` attributes.
/// The enum carries no data, so it is `Copy` and `Hash` and can be passed by
/// value or used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShaderStage {
    /// Declared with `@vertex`.
    Vertex,
    /// Declared with `@fragment`.
    Fragment,
    /// Declared with `@compute`.
    Compute,
}

/// A named entry point discovered in a WGSL module.
///
/// Comparable and hashable so callers can diff the entry-point lists of two
/// versions of the same shader.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EntryPoint {
    /// Function name as written after `fn`.
    pub name: String,
    /// Stage attribute attached to the function.
    pub stage: ShaderStage,
}

impl EntryPoint {
    /// Creates an entry point from a function name and its stage.
    pub fn new(name: impl Into<String>, stage: ShaderStage) -> Self {
        Self {
            name: name.into(),
            stage,
        }
    }
}
36
37#[derive(Debug, Clone)]
41pub struct ShaderSource {
42 pub label: String,
44 pub wgsl_source: String,
46 pub entry_points: Vec<EntryPoint>,
48 pub version: u64,
51 pub last_modified: u64,
54}
55
56impl ShaderSource {
57 fn new(label: impl Into<String>, wgsl_source: impl Into<String>) -> Self {
59 let wgsl = wgsl_source.into();
60 let entry_points = parse_entry_points(&wgsl);
61 Self {
62 label: label.into(),
63 wgsl_source: wgsl,
64 entry_points,
65 version: 1,
66 last_modified: 0,
67 }
68 }
69
70 fn bump(&mut self, new_wgsl: impl Into<String>) {
72 self.wgsl_source = new_wgsl.into();
73 self.entry_points = parse_entry_points(&self.wgsl_source);
74 self.version += 1;
75 self.last_modified = current_unix_secs();
76 }
77}
78
79fn parse_entry_points(wgsl: &str) -> Vec<EntryPoint> {
84 let mut entries = Vec::new();
85 let mut lines = wgsl.lines().peekable();
86
87 while let Some(line) = lines.next() {
88 let trimmed = line.trim();
89
90 let stage_opt = if trimmed.contains("@vertex") {
92 Some(ShaderStage::Vertex)
93 } else if trimmed.contains("@fragment") {
94 Some(ShaderStage::Fragment)
95 } else if trimmed.contains("@compute") {
96 Some(ShaderStage::Compute)
97 } else {
98 None
99 };
100
101 if let Some(stage) = stage_opt {
102 let fn_name = extract_fn_name(trimmed)
104 .or_else(|| lines.peek().and_then(|next| extract_fn_name(next.trim())));
105
106 if let Some(name) = fn_name {
107 entries.push(EntryPoint::new(name, stage));
108 }
109 }
110 }
111
112 entries
113}
114
/// Returns the identifier that follows the first `fn ` on `line`, if any.
///
/// The name ends at the first `(` or whitespace character. An empty name
/// (e.g. `"fn ("`) or a line without `fn ` yields `None`.
fn extract_fn_name(line: &str) -> Option<String> {
    let (_, tail) = line.split_once("fn ")?;
    let ident: &str = tail
        .trim()
        .split(|ch: char| ch == '(' || ch.is_whitespace())
        .next()
        .unwrap_or_default();
    (!ident.is_empty()).then(|| ident.to_owned())
}
128
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Falls back to 0 if the system clock reports a time before the epoch.
fn current_unix_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
138
/// Reports that a tracked source's version changed between two polls.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare whole events.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShaderChangeEvent {
    /// Label of the source that changed.
    pub label: String,
    /// Version recorded at the previous poll (0 if none was recorded).
    pub old_version: u64,
    /// Version observed at this poll.
    pub new_version: u64,
}
149
150pub struct ShaderWatcher {
159 pub watch_paths: Vec<String>,
161 pub poll_interval_ms: u64,
163 pub sources: HashMap<String, ShaderSource>,
165 snapshot: HashMap<String, u64>,
167}
168
169impl ShaderWatcher {
170 pub fn new(poll_interval_ms: u64) -> Self {
172 Self {
173 watch_paths: Vec::new(),
174 poll_interval_ms,
175 sources: HashMap::new(),
176 snapshot: HashMap::new(),
177 }
178 }
179
180 pub fn add_path(&mut self, path: impl Into<String>) {
186 self.watch_paths.push(path.into());
187 }
188
189 pub fn add_inline(&mut self, label: impl Into<String>, wgsl: impl Into<String>) {
192 let lbl: String = label.into();
193 let src = ShaderSource::new(lbl.clone(), wgsl);
194 self.snapshot.insert(lbl.clone(), src.version);
195 self.sources.insert(lbl, src);
196 }
197
198 pub fn poll_changes(&mut self) -> Vec<ShaderChangeEvent> {
205 let mut events = Vec::new();
206
207 for (label, src) in &self.sources {
208 let snap_version = self.snapshot.get(label).copied().unwrap_or(0);
209 if src.version != snap_version {
210 events.push(ShaderChangeEvent {
211 label: label.clone(),
212 old_version: snap_version,
213 new_version: src.version,
214 });
215 }
216 }
217
218 for (label, src) in &self.sources {
220 self.snapshot.insert(label.clone(), src.version);
221 }
222
223 events
224 }
225
226 pub fn update_source(&mut self, label: &str, new_wgsl: impl Into<String>) -> bool {
229 if let Some(src) = self.sources.get_mut(label) {
230 src.bump(new_wgsl);
231 true
232 } else {
233 false
234 }
235 }
236
237 pub fn get_source(&self, label: &str) -> Option<&ShaderSource> {
239 self.sources.get(label)
240 }
241
242 pub fn source_version(&self, label: &str) -> Option<u64> {
245 self.sources.get(label).map(|s| s.version)
246 }
247}
248
249pub struct HotReloadRegistry {
254 pub watcher: ShaderWatcher,
255 pub invalidated_pipelines: HashSet<String>,
257 pub reload_count: u64,
259 pipeline_deps: HashMap<String, HashSet<String>>,
261}
262
263impl Default for HotReloadRegistry {
264 fn default() -> Self {
265 Self::new()
266 }
267}
268
269impl HotReloadRegistry {
270 pub fn new() -> Self {
272 Self {
273 watcher: ShaderWatcher::new(500),
274 invalidated_pipelines: HashSet::new(),
275 reload_count: 0,
276 pipeline_deps: HashMap::new(),
277 }
278 }
279
280 pub fn register_pipeline(&mut self, pipeline_id: impl Into<String>, shader_label: &str) {
285 self.pipeline_deps
286 .entry(pipeline_id.into())
287 .or_default()
288 .insert(shader_label.to_owned());
289 }
290
291 pub fn process_changes(&mut self) -> Vec<String> {
295 let events = self.watcher.poll_changes();
296 if events.is_empty() {
297 return Vec::new();
298 }
299
300 let changed_labels: HashSet<&str> = events.iter().map(|e| e.label.as_str()).collect();
301
302 let mut newly_invalidated = Vec::new();
303
304 for (pipeline_id, deps) in &self.pipeline_deps {
305 if deps.iter().any(|l| changed_labels.contains(l.as_str()))
306 && !self.invalidated_pipelines.contains(pipeline_id)
307 {
308 newly_invalidated.push(pipeline_id.clone());
309 }
310 }
311
312 for id in &newly_invalidated {
313 self.invalidated_pipelines.insert(id.clone());
314 }
315
316 self.reload_count += events.len() as u64;
317 newly_invalidated
318 }
319
320 pub fn is_invalidated(&self, pipeline_id: &str) -> bool {
322 self.invalidated_pipelines.contains(pipeline_id)
323 }
324
325 pub fn clear_invalidated(&mut self, pipeline_id: &str) {
327 self.invalidated_pipelines.remove(pipeline_id);
328 }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    // --- parse_entry_points ---------------------------------------------

    #[test]
    fn test_parse_entry_points_compute() {
        let wgsl = "@compute @workgroup_size(64)\nfn main() {}";
        let eps = parse_entry_points(wgsl);
        assert_eq!(eps.len(), 1);
        assert_eq!(eps[0].name, "main");
        assert_eq!(eps[0].stage, ShaderStage::Compute);
    }

    #[test]
    fn test_parse_entry_points_vertex_fragment() {
        let wgsl = "@vertex fn vs_main() {}\n@fragment fn fs_main() {}";
        let eps = parse_entry_points(wgsl);
        assert_eq!(eps.len(), 2);
        assert!(eps
            .iter()
            .any(|e| e.name == "vs_main" && e.stage == ShaderStage::Vertex));
        assert!(eps
            .iter()
            .any(|e| e.name == "fs_main" && e.stage == ShaderStage::Fragment));
    }

    #[test]
    fn test_parse_no_entry_points() {
        let wgsl = "struct Foo { x: f32 }";
        assert!(parse_entry_points(wgsl).is_empty());
    }

    // --- ShaderWatcher ----------------------------------------------------

    #[test]
    fn test_add_inline_and_get() {
        let mut w = ShaderWatcher::new(100);
        w.add_inline("my_shader", "@compute fn main() {}");
        let src = w.get_source("my_shader").expect("source should exist");
        assert_eq!(src.label, "my_shader");
        assert_eq!(src.version, 1);
    }

    #[test]
    fn test_get_unknown_label_returns_none() {
        let w = ShaderWatcher::new(100);
        assert!(w.get_source("unknown").is_none());
    }

    #[test]
    fn test_source_version_initial() {
        let mut w = ShaderWatcher::new(100);
        w.add_inline("s", "@compute fn main() {}");
        assert_eq!(w.source_version("s"), Some(1));
    }

    #[test]
    fn test_source_version_unknown() {
        let w = ShaderWatcher::new(100);
        assert_eq!(w.source_version("nope"), None);
    }

    #[test]
    fn test_update_source_bumps_version() {
        let mut w = ShaderWatcher::new(100);
        w.add_inline("s", "@compute fn main() {}");
        assert!(w.update_source("s", "@compute fn main_v2() {}"));
        assert_eq!(w.source_version("s"), Some(2));
    }

    #[test]
    fn test_update_source_unknown_returns_false() {
        let mut w = ShaderWatcher::new(100);
        assert!(!w.update_source("ghost", "@compute fn x() {}"));
    }

    #[test]
    fn test_update_source_multiple_bumps() {
        let mut w = ShaderWatcher::new(100);
        w.add_inline("s", "fn main() {}");
        for expected in 2..=5_u64 {
            assert!(w.update_source("s", format!("fn main_{expected}() {{}}")));
            assert_eq!(w.source_version("s"), Some(expected));
        }
    }

    #[test]
    fn test_poll_changes_after_update() {
        let mut w = ShaderWatcher::new(100);
        w.add_inline("s", "@compute fn main() {}");
        // Registration records the initial version, so nothing is reported yet.
        let first = w.poll_changes();
        assert!(first.is_empty(), "first poll should be empty");

        assert!(w.update_source("s", "@compute fn main_v2() {}"));
        let second = w.poll_changes();
        assert_eq!(second.len(), 1);
        assert_eq!(second[0].label, "s");
        assert_eq!(second[0].old_version, 1);
        assert_eq!(second[0].new_version, 2);
    }

    #[test]
    fn test_poll_changes_clears_on_second_poll() {
        let mut w = ShaderWatcher::new(100);
        w.add_inline("s", "fn main() {}");
        assert!(w.update_source("s", "fn main_v2() {}"));
        assert_eq!(w.poll_changes().len(), 1);
        assert!(w.poll_changes().is_empty());
    }

    #[test]
    fn test_add_path_stores_path() {
        let mut w = ShaderWatcher::new(100);
        w.add_path("/tmp/test_shader.wgsl");
        assert_eq!(w.watch_paths, vec!["/tmp/test_shader.wgsl".to_owned()]);
    }

    #[test]
    fn test_multiple_inline_sources() {
        let mut w = ShaderWatcher::new(100);
        w.add_inline("a", "fn a() {}");
        w.add_inline("b", "fn b() {}");
        assert_eq!(w.sources.len(), 2);
    }

    // --- HotReloadRegistry -----------------------------------------------

    #[test]
    fn test_registry_new_not_invalidated() {
        let reg = HotReloadRegistry::new();
        assert!(!reg.is_invalidated("pipeline_a"));
    }

    #[test]
    fn test_registry_process_changes_invalidates_pipeline() {
        let mut reg = HotReloadRegistry::new();
        reg.watcher.add_inline("my_shader", "@compute fn main() {}");
        reg.register_pipeline("pipeline_a", "my_shader");

        // Registration alone is not a change.
        assert!(reg.watcher.poll_changes().is_empty());
        assert!(reg
            .watcher
            .update_source("my_shader", "@compute fn main_v2() {}"));

        let invalidated = reg.process_changes();
        assert_eq!(invalidated, vec!["pipeline_a".to_owned()]);
        assert!(reg.is_invalidated("pipeline_a"));
    }

    #[test]
    fn test_registry_process_changes_no_change() {
        let mut reg = HotReloadRegistry::new();
        reg.watcher.add_inline("s", "@compute fn main() {}");
        reg.register_pipeline("p", "s");
        reg.watcher.poll_changes();
        let invalidated = reg.process_changes();
        assert!(invalidated.is_empty());
        assert_eq!(reg.reload_count, 0);
    }

    #[test]
    fn test_registry_clear_invalidated() {
        let mut reg = HotReloadRegistry::new();
        reg.watcher.add_inline("s", "@compute fn main() {}");
        reg.register_pipeline("p", "s");
        reg.watcher.poll_changes();
        assert!(reg.watcher.update_source("s", "@compute fn new_main() {}"));
        reg.process_changes();
        assert!(reg.is_invalidated("p"));
        reg.clear_invalidated("p");
        assert!(!reg.is_invalidated("p"));
    }

    #[test]
    fn test_registry_reload_count_increments() {
        let mut reg = HotReloadRegistry::new();
        reg.watcher.add_inline("s", "fn main() {}");
        reg.register_pipeline("p", "s");
        reg.watcher.poll_changes();
        assert!(reg.watcher.update_source("s", "fn main_v2() {}"));
        reg.process_changes();
        assert_eq!(reg.reload_count, 1);
        assert!(reg.watcher.update_source("s", "fn main_v3() {}"));
        reg.process_changes();
        assert_eq!(reg.reload_count, 2);
    }

    #[test]
    fn test_registry_unrelated_shader_does_not_invalidate() {
        let mut reg = HotReloadRegistry::new();
        reg.watcher.add_inline("shader_a", "fn a() {}");
        reg.watcher.add_inline("shader_b", "fn b() {}");
        reg.register_pipeline("pipeline_a", "shader_a");
        reg.watcher.poll_changes();

        assert!(reg.watcher.update_source("shader_b", "fn b_v2() {}"));
        let invalidated = reg.process_changes();
        assert!(invalidated.is_empty());
        assert!(!reg.is_invalidated("pipeline_a"));
    }

    // --- EntryPoint / ShaderSource -----------------------------------------

    #[test]
    fn test_entry_point_new() {
        let ep = EntryPoint::new("vs_main", ShaderStage::Vertex);
        assert_eq!(ep.name, "vs_main");
        assert_eq!(ep.stage, ShaderStage::Vertex);
    }

    #[test]
    fn test_shader_source_entry_points_populated() {
        let mut w = ShaderWatcher::new(100);
        w.add_inline("s", "@compute\nfn my_compute() {}");
        let src = w.get_source("s").expect("source should exist");
        assert_eq!(src.entry_points.len(), 1);
        assert_eq!(src.entry_points[0].name, "my_compute");
        assert_eq!(src.entry_points[0].stage, ShaderStage::Compute);
    }

    #[test]
    fn test_update_source_refreshes_entry_points() {
        let mut w = ShaderWatcher::new(100);
        w.add_inline("s", "@compute fn compute_v1() {}");
        assert!(w.update_source("s", "@vertex fn vs_main() {}"));
        let src = w.get_source("s").expect("source should exist");
        assert_eq!(src.entry_points.len(), 1);
        assert_eq!(src.entry_points[0].stage, ShaderStage::Vertex);
        assert_eq!(src.entry_points[0].name, "vs_main");
    }
}