Skip to main content

oxigdal_gpu/
shader_reload.rs

//! Shader hot-reload support for the GPU rendering pipeline.
//!
//! Provides [`ShaderWatcher`] for tracking WGSL shader sources and their
//! versions, and [`HotReloadRegistry`] for mapping render pipelines to the
//! shaders they depend on, so that dependent pipelines can be flagged for a
//! rebuild when their shader source changes.
7
8use std::collections::{HashMap, HashSet};
9
// ─── Entry points & stage ────────────────────────────────────────────────────

/// Shader pipeline stage an entry point runs in.
///
/// Fieldless, so it is cheap to copy and can be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShaderStage {
    /// Entry point annotated with `@vertex`.
    Vertex,
    /// Entry point annotated with `@fragment`.
    Fragment,
    /// Entry point annotated with `@compute`.
    Compute,
}

/// A named entry point within a shader module.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EntryPoint {
    /// Function name of the entry point (the identifier after `fn`).
    pub name: String,
    /// Pipeline stage the entry point is annotated with.
    pub stage: ShaderStage,
}

impl EntryPoint {
    /// Create a new entry point descriptor.
    pub fn new(name: impl Into<String>, stage: ShaderStage) -> Self {
        Self {
            name: name.into(),
            stage,
        }
    }
}
36
37// ─── ShaderSource ─────────────────────────────────────────────────────────────
38
39/// A versioned WGSL shader source record.
40#[derive(Debug, Clone)]
41pub struct ShaderSource {
42    /// Human-readable label used as the map key.
43    pub label: String,
44    /// Raw WGSL text.
45    pub wgsl_source: String,
46    /// Declared entry points (computed on insertion / update).
47    pub entry_points: Vec<EntryPoint>,
48    /// Monotonically increasing version counter; starts at `1`, increments on
49    /// every call to [`ShaderWatcher::update_source`].
50    pub version: u64,
51    /// Unix timestamp (seconds) of the last modification.
52    /// In an embedded / no-filesystem context this defaults to `0`.
53    pub last_modified: u64,
54}
55
56impl ShaderSource {
57    /// Construct an initial `ShaderSource` at version `1`.
58    fn new(label: impl Into<String>, wgsl_source: impl Into<String>) -> Self {
59        let wgsl = wgsl_source.into();
60        let entry_points = parse_entry_points(&wgsl);
61        Self {
62            label: label.into(),
63            wgsl_source: wgsl,
64            entry_points,
65            version: 1,
66            last_modified: 0,
67        }
68    }
69
70    /// Bump the version and replace the WGSL source.
71    fn bump(&mut self, new_wgsl: impl Into<String>) {
72        self.wgsl_source = new_wgsl.into();
73        self.entry_points = parse_entry_points(&self.wgsl_source);
74        self.version += 1;
75        self.last_modified = current_unix_secs();
76    }
77}
78
79/// Cheaply parse entry-point names and stages from WGSL source text.
80///
81/// Looks for `@vertex`, `@fragment`, and `@compute` annotations followed by
82/// a `fn <name>` declaration on the same or next line.
83fn parse_entry_points(wgsl: &str) -> Vec<EntryPoint> {
84    let mut entries = Vec::new();
85    let mut lines = wgsl.lines().peekable();
86
87    while let Some(line) = lines.next() {
88        let trimmed = line.trim();
89
90        // Determine if this line has a stage attribute.
91        let stage_opt = if trimmed.contains("@vertex") {
92            Some(ShaderStage::Vertex)
93        } else if trimmed.contains("@fragment") {
94            Some(ShaderStage::Fragment)
95        } else if trimmed.contains("@compute") {
96            Some(ShaderStage::Compute)
97        } else {
98            None
99        };
100
101        if let Some(stage) = stage_opt {
102            // The fn declaration may be on the same line or the next.
103            let fn_name = extract_fn_name(trimmed)
104                .or_else(|| lines.peek().and_then(|next| extract_fn_name(next.trim())));
105
106            if let Some(name) = fn_name {
107                entries.push(EntryPoint::new(name, stage));
108            }
109        }
110    }
111
112    entries
113}
114
115/// Extract the function name from a line of the form `fn <name>(...)`.
116fn extract_fn_name(line: &str) -> Option<String> {
117    let idx = line.find("fn ")?;
118    let after = line[idx + 3..].trim();
119    // Name ends at '(' or whitespace.
120    let end = after
121        .find(|c: char| c == '(' || c.is_whitespace())
122        .unwrap_or(after.len());
123    if end == 0 {
124        return None;
125    }
126    Some(after[..end].to_owned())
127}
128
/// Returns the current time as whole seconds since the Unix epoch.
///
/// Returns `0` instead of panicking when no usable wall clock exists:
/// * on `wasm32-unknown-unknown` the clock is never queried, because
///   `SystemTime::now()` panics on that target (it has no time source);
/// * if the system clock reports a time before the Unix epoch.
fn current_unix_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    if cfg!(all(target_arch = "wasm32", target_os = "unknown")) {
        return 0;
    }
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_secs())
        .unwrap_or(0)
}
138
// ─── ShaderChangeEvent ────────────────────────────────────────────────────────

/// Emitted by [`ShaderWatcher::poll_changes`] when a source's version differs
/// from the one recorded at the previous poll.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ShaderChangeEvent {
    /// Label of the source that changed.
    pub label: String,
    /// Version recorded at the previous poll (`0` if none was recorded).
    pub old_version: u64,
    /// Version of the source at the time of this poll.
    pub new_version: u64,
}
149
150// ─── ShaderWatcher ────────────────────────────────────────────────────────────
151
152/// Watches a collection of named WGSL shader sources for changes.
153///
154/// In a no-filesystem context (embedded, WASM, tests) changes are driven
155/// explicitly by calling [`ShaderWatcher::update_source`].  On native
156/// targets an optional polling loop can check file modification times
157/// (see [`ShaderWatcher::poll_changes`]).
158pub struct ShaderWatcher {
159    /// Filesystem paths added via [`Self::add_path`]; stored for future polling.
160    pub watch_paths: Vec<String>,
161    /// Polling interval hint in milliseconds (informational only).
162    pub poll_interval_ms: u64,
163    /// All tracked sources keyed by label.
164    pub sources: HashMap<String, ShaderSource>,
165    /// Snapshot of versions from the last [`Self::poll_changes`] call.
166    snapshot: HashMap<String, u64>,
167}
168
169impl ShaderWatcher {
170    /// Create a new watcher with the given poll interval.
171    pub fn new(poll_interval_ms: u64) -> Self {
172        Self {
173            watch_paths: Vec::new(),
174            poll_interval_ms,
175            sources: HashMap::new(),
176            snapshot: HashMap::new(),
177        }
178    }
179
180    /// Register a filesystem path to watch.
181    ///
182    /// The label used for the source will be the path string itself.
183    /// The file is not loaded immediately; call [`Self::update_source`] with its
184    /// content or rely on a future polling implementation.
185    pub fn add_path(&mut self, path: impl Into<String>) {
186        self.watch_paths.push(path.into());
187    }
188
189    /// Register an inline WGSL source by label.  If a source with the same
190    /// label already exists it is replaced (version resets to `1`).
191    pub fn add_inline(&mut self, label: impl Into<String>, wgsl: impl Into<String>) {
192        let lbl: String = label.into();
193        let src = ShaderSource::new(lbl.clone(), wgsl);
194        self.snapshot.insert(lbl.clone(), src.version);
195        self.sources.insert(lbl, src);
196    }
197
198    /// Check whether any tracked sources have changed since the last call to
199    /// `poll_changes`.
200    ///
201    /// In the current implementation changes are detected purely by comparing
202    /// in-memory version numbers; actual filesystem polling is not yet wired up.
203    /// Returns the list of [`ShaderChangeEvent`]s describing what changed.
204    pub fn poll_changes(&mut self) -> Vec<ShaderChangeEvent> {
205        let mut events = Vec::new();
206
207        for (label, src) in &self.sources {
208            let snap_version = self.snapshot.get(label).copied().unwrap_or(0);
209            if src.version != snap_version {
210                events.push(ShaderChangeEvent {
211                    label: label.clone(),
212                    old_version: snap_version,
213                    new_version: src.version,
214                });
215            }
216        }
217
218        // Update snapshot to current state.
219        for (label, src) in &self.sources {
220            self.snapshot.insert(label.clone(), src.version);
221        }
222
223        events
224    }
225
226    /// Force-update the WGSL source for an existing label, bumping its
227    /// version.  Returns `true` on success, `false` if the label is unknown.
228    pub fn update_source(&mut self, label: &str, new_wgsl: impl Into<String>) -> bool {
229        if let Some(src) = self.sources.get_mut(label) {
230            src.bump(new_wgsl);
231            true
232        } else {
233            false
234        }
235    }
236
237    /// Look up a source by label.
238    pub fn get_source(&self, label: &str) -> Option<&ShaderSource> {
239        self.sources.get(label)
240    }
241
242    /// Return the current version of a source, or `None` if the label is
243    /// not registered.
244    pub fn source_version(&self, label: &str) -> Option<u64> {
245        self.sources.get(label).map(|s| s.version)
246    }
247}
248
249// ─── HotReloadRegistry ────────────────────────────────────────────────────────
250
251/// Maps render pipeline IDs to the shader labels they depend on and
252/// automatically marks pipelines as invalidated when their shaders change.
253pub struct HotReloadRegistry {
254    pub watcher: ShaderWatcher,
255    /// Set of pipeline IDs that need to be rebuilt.
256    pub invalidated_pipelines: HashSet<String>,
257    /// Total number of successful reloads processed.
258    pub reload_count: u64,
259    /// pipeline_id → set of shader labels it depends on.
260    pipeline_deps: HashMap<String, HashSet<String>>,
261}
262
263impl Default for HotReloadRegistry {
264    fn default() -> Self {
265        Self::new()
266    }
267}
268
269impl HotReloadRegistry {
270    /// Create a new registry with a default watcher (500 ms poll interval).
271    pub fn new() -> Self {
272        Self {
273            watcher: ShaderWatcher::new(500),
274            invalidated_pipelines: HashSet::new(),
275            reload_count: 0,
276            pipeline_deps: HashMap::new(),
277        }
278    }
279
280    /// Register a pipeline as depending on a shader label.
281    ///
282    /// A pipeline may depend on multiple shaders; call this method once per
283    /// shader dependency.
284    pub fn register_pipeline(&mut self, pipeline_id: impl Into<String>, shader_label: &str) {
285        self.pipeline_deps
286            .entry(pipeline_id.into())
287            .or_default()
288            .insert(shader_label.to_owned());
289    }
290
291    /// Poll for shader changes and mark dependent pipelines as invalidated.
292    ///
293    /// Returns the list of pipeline IDs that have been newly invalidated.
294    pub fn process_changes(&mut self) -> Vec<String> {
295        let events = self.watcher.poll_changes();
296        if events.is_empty() {
297            return Vec::new();
298        }
299
300        let changed_labels: HashSet<&str> = events.iter().map(|e| e.label.as_str()).collect();
301
302        let mut newly_invalidated = Vec::new();
303
304        for (pipeline_id, deps) in &self.pipeline_deps {
305            if deps.iter().any(|l| changed_labels.contains(l.as_str()))
306                && !self.invalidated_pipelines.contains(pipeline_id)
307            {
308                newly_invalidated.push(pipeline_id.clone());
309            }
310        }
311
312        for id in &newly_invalidated {
313            self.invalidated_pipelines.insert(id.clone());
314        }
315
316        self.reload_count += events.len() as u64;
317        newly_invalidated
318    }
319
320    /// Returns `true` if the pipeline is currently marked as invalidated.
321    pub fn is_invalidated(&self, pipeline_id: &str) -> bool {
322        self.invalidated_pipelines.contains(pipeline_id)
323    }
324
325    /// Clear the invalidation flag for a pipeline after it has been rebuilt.
326    pub fn clear_invalidated(&mut self, pipeline_id: &str) {
327        self.invalidated_pipelines.remove(pipeline_id);
328    }
329}
330
// ─── Tests ────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Watcher (100 ms hint) pre-loaded with one inline source.
    fn watcher_with(label: &str, wgsl: &str) -> ShaderWatcher {
        let mut watcher = ShaderWatcher::new(100);
        watcher.add_inline(label, wgsl);
        watcher
    }

    /// Registry whose `pipeline` depends on shader `label`, with the initial
    /// snapshot already drained.
    fn registry_with(label: &str, wgsl: &str, pipeline: &str) -> HotReloadRegistry {
        let mut registry = HotReloadRegistry::new();
        registry.watcher.add_inline(label, wgsl);
        registry.register_pipeline(pipeline, label);
        let _ = registry.watcher.poll_changes();
        registry
    }

    // ── ShaderSource parsing ─────────────────────────────────────────────────

    #[test]
    fn test_parse_entry_points_compute() {
        let parsed = parse_entry_points("@compute @workgroup_size(64)\nfn main() {}");
        match parsed.as_slice() {
            [only] => {
                assert_eq!(only.name, "main");
                assert_eq!(only.stage, ShaderStage::Compute);
            }
            other => panic!("expected exactly one entry point, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_entry_points_vertex_fragment() {
        let parsed = parse_entry_points("@vertex fn vs_main() {}\n@fragment fn fs_main() {}");
        assert_eq!(parsed.len(), 2);
        let names: Vec<&str> = parsed.iter().map(|ep| ep.name.as_str()).collect();
        assert!(names.contains(&"vs_main"));
        assert!(names.contains(&"fs_main"));
    }

    #[test]
    fn test_parse_no_entry_points() {
        let parsed = parse_entry_points("struct Foo { x: f32 }");
        assert!(parsed.is_empty(), "unexpected entry points: {parsed:?}");
    }

    // ── ShaderWatcher ────────────────────────────────────────────────────────

    #[test]
    fn test_add_inline_and_get() {
        let watcher = watcher_with("my_shader", "@compute fn main() {}");
        let src = watcher
            .get_source("my_shader")
            .expect("source should exist");
        assert_eq!(src.label, "my_shader");
        assert_eq!(src.version, 1);
    }

    #[test]
    fn test_get_unknown_label_returns_none() {
        assert!(ShaderWatcher::new(100).get_source("unknown").is_none());
    }

    #[test]
    fn test_source_version_initial() {
        let watcher = watcher_with("s", "@compute fn main() {}");
        assert_eq!(watcher.source_version("s"), Some(1));
    }

    #[test]
    fn test_source_version_unknown() {
        assert_eq!(ShaderWatcher::new(100).source_version("nope"), None);
    }

    #[test]
    fn test_update_source_bumps_version() {
        let mut watcher = watcher_with("s", "@compute fn main() {}");
        assert!(watcher.update_source("s", "@compute fn main_v2() {}"));
        assert_eq!(watcher.source_version("s"), Some(2));
    }

    #[test]
    fn test_update_source_unknown_returns_false() {
        let mut watcher = ShaderWatcher::new(100);
        let updated = watcher.update_source("ghost", "@compute fn x() {}");
        assert!(!updated);
    }

    #[test]
    fn test_update_source_multiple_bumps() {
        let mut watcher = watcher_with("s", "fn main() {}");
        (2..=5_u64).for_each(|expected| {
            watcher.update_source("s", format!("fn main_{expected}() {{}}"));
            assert_eq!(watcher.source_version("s"), Some(expected));
        });
    }

    #[test]
    fn test_poll_changes_after_update() {
        let mut watcher = watcher_with("s", "@compute fn main() {}");
        // add_inline records the new source in the snapshot, so nothing is pending.
        assert!(watcher.poll_changes().is_empty(), "first poll should be empty");

        watcher.update_source("s", "@compute fn main_v2() {}");
        let events = watcher.poll_changes();
        assert_eq!(events.len(), 1);
        let event = &events[0];
        assert_eq!(event.label, "s");
        assert_eq!((event.old_version, event.new_version), (1, 2));
    }

    #[test]
    fn test_poll_changes_clears_on_second_poll() {
        let mut watcher = watcher_with("s", "fn main() {}");
        watcher.update_source("s", "fn main_v2() {}");
        let _ = watcher.poll_changes();
        // Nothing has changed since the previous poll.
        assert!(watcher.poll_changes().is_empty());
    }

    #[test]
    fn test_add_path_stores_path() {
        let mut watcher = ShaderWatcher::new(100);
        watcher.add_path("/tmp/test_shader.wgsl");
        assert_eq!(watcher.watch_paths, ["/tmp/test_shader.wgsl"]);
    }

    #[test]
    fn test_multiple_inline_sources() {
        let mut watcher = watcher_with("a", "fn a() {}");
        watcher.add_inline("b", "fn b() {}");
        assert_eq!(watcher.sources.len(), 2);
    }

    // ── HotReloadRegistry ────────────────────────────────────────────────────

    #[test]
    fn test_registry_new_not_invalidated() {
        assert!(!HotReloadRegistry::new().is_invalidated("pipeline_a"));
    }

    #[test]
    fn test_registry_process_changes_invalidates_pipeline() {
        let mut registry = registry_with("my_shader", "@compute fn main() {}", "pipeline_a");
        registry
            .watcher
            .update_source("my_shader", "@compute fn main_v2() {}");

        let invalidated = registry.process_changes();
        assert!(invalidated.iter().any(|id| id == "pipeline_a"));
        assert!(registry.is_invalidated("pipeline_a"));
    }

    #[test]
    fn test_registry_process_changes_no_change() {
        let mut registry = registry_with("s", "@compute fn main() {}", "p");
        assert!(registry.process_changes().is_empty());
    }

    #[test]
    fn test_registry_clear_invalidated() {
        let mut registry = registry_with("s", "@compute fn main() {}", "p");
        registry.watcher.update_source("s", "@compute fn new_main() {}");
        registry.process_changes();
        assert!(registry.is_invalidated("p"));

        registry.clear_invalidated("p");
        assert!(!registry.is_invalidated("p"));
    }

    #[test]
    fn test_registry_reload_count_increments() {
        let mut registry = registry_with("s", "fn main() {}", "p");
        for (round, wgsl) in ["fn main_v2() {}", "fn main_v3() {}"].iter().enumerate() {
            registry.watcher.update_source("s", *wgsl);
            registry.process_changes();
            assert_eq!(registry.reload_count, round as u64 + 1);
        }
    }

    #[test]
    fn test_registry_unrelated_shader_does_not_invalidate() {
        let mut registry = HotReloadRegistry::new();
        for (label, wgsl) in [("shader_a", "fn a() {}"), ("shader_b", "fn b() {}")].iter() {
            registry.watcher.add_inline(*label, *wgsl);
        }
        registry.register_pipeline("pipeline_a", "shader_a");
        let _ = registry.watcher.poll_changes();

        // Only shader_b changes; pipeline_a depends solely on shader_a.
        registry.watcher.update_source("shader_b", "fn b_v2() {}");
        let invalidated = registry.process_changes();
        assert!(!invalidated.iter().any(|id| id == "pipeline_a"));
        assert!(!registry.is_invalidated("pipeline_a"));
    }

    #[test]
    fn test_entry_point_new() {
        let ep = EntryPoint::new("vs_main", ShaderStage::Vertex);
        assert_eq!(
            (ep.name.as_str(), &ep.stage),
            ("vs_main", &ShaderStage::Vertex)
        );
    }

    #[test]
    fn test_shader_source_entry_points_populated() {
        let watcher = watcher_with("s", "@compute\nfn my_compute() {}");
        let entry_points = &watcher
            .get_source("s")
            .expect("source should exist")
            .entry_points;
        assert_eq!(entry_points.len(), 1);
        assert_eq!(entry_points[0].name, "my_compute");
    }

    #[test]
    fn test_update_source_refreshes_entry_points() {
        let mut watcher = watcher_with("s", "@compute fn compute_v1() {}");
        watcher.update_source("s", "@vertex fn vs_main() {}");
        let first = &watcher
            .get_source("s")
            .expect("source should exist")
            .entry_points[0];
        assert_eq!(first.stage, ShaderStage::Vertex);
        assert_eq!(first.name, "vs_main");
    }
}