use std::collections::{HashMap, HashSet};
/// Pipeline stage a WGSL entry point is annotated for.
///
/// `Copy` and `Hash` are derived because this is a fieldless tag: it is cheap to copy and
/// useful as a key in sets/maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShaderStage {
    /// Function annotated with `@vertex`.
    Vertex,
    /// Function annotated with `@fragment`.
    Fragment,
    /// Function annotated with `@compute`.
    Compute,
}
/// A shader entry point found in WGSL source: the function's name and its pipeline stage.
#[derive(Clone, Debug)]
pub struct EntryPoint {
    /// Name of the entry-point function.
    pub name: String,
    /// Stage the function is annotated for.
    pub stage: ShaderStage,
}

impl EntryPoint {
    /// Builds an entry point from any string-like `name` and a `stage`.
    pub fn new(name: impl Into<String>, stage: ShaderStage) -> Self {
        let name: String = name.into();
        EntryPoint { name, stage }
    }
}
/// One tracked WGSL shader, together with the entry points parsed from its text.
#[derive(Clone, Debug)]
pub struct ShaderSource {
    /// Label the shader is registered under.
    pub label: String,
    /// Current WGSL text.
    pub wgsl_source: String,
    /// Entry points parsed from `wgsl_source`; refreshed whenever the text is replaced.
    pub entry_points: Vec<EntryPoint>,
    /// Starts at 1 and increases by one on every edit.
    pub version: u64,
    /// Unix seconds of the most recent edit; 0 until the first edit.
    pub last_modified: u64,
}

impl ShaderSource {
    /// Creates version 1 of a shader, parsing its entry points immediately.
    fn new(label: impl Into<String>, wgsl_source: impl Into<String>) -> Self {
        let wgsl_source: String = wgsl_source.into();
        Self {
            label: label.into(),
            entry_points: parse_entry_points(&wgsl_source),
            wgsl_source,
            version: 1,
            last_modified: 0,
        }
    }

    /// Replaces the WGSL text and re-parses its entry points.
    ///
    /// Also increments `version` and stamps `last_modified` with the current Unix time.
    fn bump(&mut self, new_wgsl: impl Into<String>) {
        let text: String = new_wgsl.into();
        self.entry_points = parse_entry_points(&text);
        self.wgsl_source = text;
        self.version += 1;
        self.last_modified = current_unix_secs();
    }
}
fn parse_entry_points(wgsl: &str) -> Vec<EntryPoint> {
let mut entries = Vec::new();
let mut lines = wgsl.lines().peekable();
while let Some(line) = lines.next() {
let trimmed = line.trim();
let stage_opt = if trimmed.contains("@vertex") {
Some(ShaderStage::Vertex)
} else if trimmed.contains("@fragment") {
Some(ShaderStage::Fragment)
} else if trimmed.contains("@compute") {
Some(ShaderStage::Compute)
} else {
None
};
if let Some(stage) = stage_opt {
let fn_name = extract_fn_name(trimmed)
.or_else(|| lines.peek().and_then(|next| extract_fn_name(next.trim())));
if let Some(name) = fn_name {
entries.push(EntryPoint::new(name, stage));
}
}
}
entries
}
/// Returns the identifier that follows the first standalone `fn` keyword on `line`.
///
/// `fn` counts as a keyword only when it is a whole word:
/// - it is not preceded by an identifier character, so `myfn` does not match;
/// - it is followed by whitespace, so `fnord` / `fn_x` do not match.
///
/// Any whitespace, including tabs, may separate `fn` from the name. The name runs until
/// the first character that is not alphanumeric or `_`.
///
/// Returns `None` when no keyword occurrence is followed by a non-empty name.
fn extract_fn_name(line: &str) -> Option<String> {
    let is_ident = |c: char| c.is_alphanumeric() || c == '_';
    line.match_indices("fn").find_map(|(idx, _)| {
        // `match_indices` yields char boundaries, and "fn" is 2 ASCII bytes, so both slices are valid.
        let starts_word = line[..idx]
            .chars()
            .next_back()
            .map_or(true, |c| !is_ident(c));
        let rest = &line[idx + 2..];
        if !starts_word || !rest.starts_with(char::is_whitespace) {
            return None;
        }
        let after = rest.trim_start();
        let end = after.find(|c: char| !is_ident(c)).unwrap_or(after.len());
        (end > 0).then(|| after[..end].to_owned())
    })
}
/// Current time in whole seconds since the Unix epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch.
fn current_unix_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// Emitted by `ShaderWatcher::poll_changes` for a source whose version differs from the last poll.
#[derive(Clone, Debug)]
pub struct ShaderChangeEvent {
    /// Label of the shader that changed.
    pub label: String,
    /// Version recorded at the previous poll (0 if none was recorded).
    pub old_version: u64,
    /// Version observed by this poll.
    pub new_version: u64,
}
pub struct ShaderWatcher {
pub watch_paths: Vec<String>,
pub poll_interval_ms: u64,
pub sources: HashMap<String, ShaderSource>,
snapshot: HashMap<String, u64>,
}
impl ShaderWatcher {
pub fn new(poll_interval_ms: u64) -> Self {
Self {
watch_paths: Vec::new(),
poll_interval_ms,
sources: HashMap::new(),
snapshot: HashMap::new(),
}
}
pub fn add_path(&mut self, path: impl Into<String>) {
self.watch_paths.push(path.into());
}
pub fn add_inline(&mut self, label: impl Into<String>, wgsl: impl Into<String>) {
let lbl: String = label.into();
let src = ShaderSource::new(lbl.clone(), wgsl);
self.snapshot.insert(lbl.clone(), src.version);
self.sources.insert(lbl, src);
}
pub fn poll_changes(&mut self) -> Vec<ShaderChangeEvent> {
let mut events = Vec::new();
for (label, src) in &self.sources {
let snap_version = self.snapshot.get(label).copied().unwrap_or(0);
if src.version != snap_version {
events.push(ShaderChangeEvent {
label: label.clone(),
old_version: snap_version,
new_version: src.version,
});
}
}
for (label, src) in &self.sources {
self.snapshot.insert(label.clone(), src.version);
}
events
}
pub fn update_source(&mut self, label: &str, new_wgsl: impl Into<String>) -> bool {
if let Some(src) = self.sources.get_mut(label) {
src.bump(new_wgsl);
true
} else {
false
}
}
pub fn get_source(&self, label: &str) -> Option<&ShaderSource> {
self.sources.get(label)
}
pub fn source_version(&self, label: &str) -> Option<u64> {
self.sources.get(label).map(|s| s.version)
}
}
/// Tracks which pipelines depend on which shaders, and which pipelines are stale.
///
/// A pipeline is marked invalidated after one of its shaders changes.
pub struct HotReloadRegistry {
    /// Watcher that owns the shader sources; `new` creates it with a 500 ms interval.
    pub watcher: ShaderWatcher,
    /// Pipelines invalidated by a shader change and not yet cleared.
    pub invalidated_pipelines: HashSet<String>,
    /// Running total of shader change events seen by `process_changes`.
    pub reload_count: u64,
    /// Pipeline id -> labels of the shaders it was registered against.
    pipeline_deps: HashMap<String, HashSet<String>>,
}

impl Default for HotReloadRegistry {
    fn default() -> Self {
        HotReloadRegistry::new()
    }
}

impl HotReloadRegistry {
    /// Creates an empty registry whose watcher uses a 500 ms poll interval.
    pub fn new() -> Self {
        let watcher = ShaderWatcher::new(500);
        Self {
            watcher,
            invalidated_pipelines: HashSet::new(),
            reload_count: 0,
            pipeline_deps: HashMap::new(),
        }
    }

    /// Records that `pipeline_id` depends on the shader labelled `shader_label`.
    ///
    /// A pipeline may be registered against several shaders.
    pub fn register_pipeline(&mut self, pipeline_id: impl Into<String>, shader_label: &str) {
        let deps = self
            .pipeline_deps
            .entry(pipeline_id.into())
            .or_insert_with(HashSet::new);
        deps.insert(shader_label.to_owned());
    }

    /// Polls the watcher and invalidates every pipeline that uses a changed shader.
    ///
    /// Returns only the pipelines newly invalidated by this call; ones already invalidated
    /// are not repeated. `reload_count` grows by the number of change events polled.
    pub fn process_changes(&mut self) -> Vec<String> {
        let events = self.watcher.poll_changes();
        if events.is_empty() {
            return Vec::new();
        }
        let changed: HashSet<&str> = events.iter().map(|ev| ev.label.as_str()).collect();
        let mut fresh = Vec::new();
        for (id, deps) in &self.pipeline_deps {
            let affected = deps.iter().any(|dep| changed.contains(dep.as_str()));
            // `insert` returns true only if the id was not already invalidated.
            if affected && self.invalidated_pipelines.insert(id.clone()) {
                fresh.push(id.clone());
            }
        }
        self.reload_count += events.len() as u64;
        fresh
    }

    /// Whether `pipeline_id` is currently marked invalidated.
    pub fn is_invalidated(&self, pipeline_id: &str) -> bool {
        self.invalidated_pipelines.contains(pipeline_id)
    }

    /// Clears the invalidated mark, e.g. once the caller has rebuilt the pipeline.
    pub fn clear_invalidated(&mut self, pipeline_id: &str) {
        self.invalidated_pipelines.remove(pipeline_id);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a watcher (100 ms interval) holding one inline shader.
    fn watcher_with(label: &str, wgsl: &str) -> ShaderWatcher {
        let mut watcher = ShaderWatcher::new(100);
        watcher.add_inline(label, wgsl);
        watcher
    }

    /// Builds a registry with one shader, used by one pipeline.
    ///
    /// The initial registration is already drained by a poll.
    fn primed_registry(label: &str, wgsl: &str, pipeline: &str) -> HotReloadRegistry {
        let mut reg = HotReloadRegistry::new();
        reg.watcher.add_inline(label, wgsl);
        reg.register_pipeline(pipeline, label);
        let _ = reg.watcher.poll_changes();
        reg
    }

    #[test]
    fn test_parse_entry_points_compute() {
        let eps = parse_entry_points("@compute @workgroup_size(64)\nfn main() {}");
        assert_eq!(eps.len(), 1);
        let ep = &eps[0];
        assert_eq!(ep.name, "main");
        assert_eq!(ep.stage, ShaderStage::Compute);
    }

    #[test]
    fn test_parse_entry_points_vertex_fragment() {
        let eps = parse_entry_points("@vertex fn vs_main() {}\n@fragment fn fs_main() {}");
        assert_eq!(eps.len(), 2);
        let names: Vec<&str> = eps.iter().map(|e| e.name.as_str()).collect();
        assert!(names.contains(&"vs_main"));
        assert!(names.contains(&"fs_main"));
    }

    #[test]
    fn test_parse_no_entry_points() {
        assert!(parse_entry_points("struct Foo { x: f32 }").is_empty());
    }

    #[test]
    fn test_add_inline_and_get() {
        let w = watcher_with("my_shader", "@compute fn main() {}");
        let src = w.get_source("my_shader").expect("source should exist");
        assert_eq!(src.label, "my_shader");
        assert_eq!(src.version, 1);
    }

    #[test]
    fn test_get_unknown_label_returns_none() {
        assert!(ShaderWatcher::new(100).get_source("unknown").is_none());
    }

    #[test]
    fn test_source_version_initial() {
        let w = watcher_with("s", "@compute fn main() {}");
        assert_eq!(w.source_version("s"), Some(1));
    }

    #[test]
    fn test_source_version_unknown() {
        assert_eq!(ShaderWatcher::new(100).source_version("nope"), None);
    }

    #[test]
    fn test_update_source_bumps_version() {
        let mut w = watcher_with("s", "@compute fn main() {}");
        assert!(w.update_source("s", "@compute fn main_v2() {}"));
        assert_eq!(w.source_version("s"), Some(2));
    }

    #[test]
    fn test_update_source_unknown_returns_false() {
        let mut w = ShaderWatcher::new(100);
        assert!(!w.update_source("ghost", "@compute fn x() {}"));
    }

    #[test]
    fn test_update_source_multiple_bumps() {
        let mut w = watcher_with("s", "fn main() {}");
        (2..=5_u64).for_each(|expected| {
            w.update_source("s", format!("fn main_{expected}() {{}}"));
            assert_eq!(w.source_version("s"), Some(expected));
        });
    }

    #[test]
    fn test_poll_changes_after_update() {
        let mut w = watcher_with("s", "@compute fn main() {}");
        assert!(w.poll_changes().is_empty(), "first poll should be empty");
        w.update_source("s", "@compute fn main_v2() {}");
        let changes = w.poll_changes();
        assert_eq!(changes.len(), 1);
        let ev = &changes[0];
        assert_eq!(ev.label, "s");
        assert_eq!(ev.old_version, 1);
        assert_eq!(ev.new_version, 2);
    }

    #[test]
    fn test_poll_changes_clears_on_second_poll() {
        let mut w = watcher_with("s", "fn main() {}");
        w.update_source("s", "fn main_v2() {}");
        let _ = w.poll_changes();
        assert!(w.poll_changes().is_empty());
    }

    #[test]
    fn test_add_path_stores_path() {
        let path_str = std::env::temp_dir()
            .join("oxigdal_test_shader_bx9f.wgsl")
            .to_string_lossy()
            .into_owned();
        let mut w = ShaderWatcher::new(100);
        w.add_path(path_str.clone());
        assert_eq!(w.watch_paths, vec![path_str]);
    }

    #[test]
    fn test_multiple_inline_sources() {
        let mut w = watcher_with("a", "fn a() {}");
        w.add_inline("b", "fn b() {}");
        assert_eq!(w.sources.len(), 2);
    }

    #[test]
    fn test_registry_new_not_invalidated() {
        assert!(!HotReloadRegistry::new().is_invalidated("pipeline_a"));
    }

    #[test]
    fn test_registry_process_changes_invalidates_pipeline() {
        let mut reg = primed_registry("my_shader", "@compute fn main() {}", "pipeline_a");
        reg.watcher
            .update_source("my_shader", "@compute fn main_v2() {}");
        let invalidated = reg.process_changes();
        assert!(invalidated.iter().any(|id| id == "pipeline_a"));
        assert!(reg.is_invalidated("pipeline_a"));
    }

    #[test]
    fn test_registry_process_changes_no_change() {
        let mut reg = primed_registry("s", "@compute fn main() {}", "p");
        assert!(reg.process_changes().is_empty());
    }

    #[test]
    fn test_registry_clear_invalidated() {
        let mut reg = primed_registry("s", "@compute fn main() {}", "p");
        reg.watcher.update_source("s", "@compute fn new_main() {}");
        reg.process_changes();
        assert!(reg.is_invalidated("p"));
        reg.clear_invalidated("p");
        assert!(!reg.is_invalidated("p"));
    }

    #[test]
    fn test_registry_reload_count_increments() {
        let mut reg = primed_registry("s", "fn main() {}", "p");
        for (expected, wgsl) in [(1_u64, "fn main_v2() {}"), (2, "fn main_v3() {}")] {
            reg.watcher.update_source("s", wgsl);
            reg.process_changes();
            assert_eq!(reg.reload_count, expected);
        }
    }

    #[test]
    fn test_registry_unrelated_shader_does_not_invalidate() {
        let mut reg = HotReloadRegistry::new();
        reg.watcher.add_inline("shader_a", "fn a() {}");
        reg.watcher.add_inline("shader_b", "fn b() {}");
        reg.register_pipeline("pipeline_a", "shader_a");
        let _ = reg.watcher.poll_changes();
        reg.watcher.update_source("shader_b", "fn b_v2() {}");
        let invalidated = reg.process_changes();
        assert!(!invalidated.iter().any(|id| id == "pipeline_a"));
        assert!(!reg.is_invalidated("pipeline_a"));
    }

    #[test]
    fn test_entry_point_new() {
        let ep = EntryPoint::new("vs_main", ShaderStage::Vertex);
        assert_eq!(ep.name, "vs_main");
        assert_eq!(ep.stage, ShaderStage::Vertex);
    }

    #[test]
    fn test_shader_source_entry_points_populated() {
        let w = watcher_with("s", "@compute\nfn my_compute() {}");
        let src = w.get_source("s").expect("source should exist");
        assert_eq!(src.entry_points.len(), 1);
        assert_eq!(src.entry_points[0].name, "my_compute");
    }

    #[test]
    fn test_update_source_refreshes_entry_points() {
        let mut w = watcher_with("s", "@compute fn compute_v1() {}");
        w.update_source("s", "@vertex fn vs_main() {}");
        let first = &w.get_source("s").expect("source should exist").entry_points[0];
        assert_eq!(first.stage, ShaderStage::Vertex);
        assert_eq!(first.name, "vs_main");
    }
}