1use std::{path::PathBuf, sync::mpsc::Receiver};
2
3use crate::*;
4use notify::{event::AccessKind, Event, EventKind, RecursiveMode, Watcher};
5
6static HOT_RELOAD: Lazy<Mutex<HotReload>> =
8 Lazy::new(|| Mutex::new(HotReload::new()));
9
/// Builds a `ReloadableShaderSource` for the fragment shader file at `$path`.
///
/// The file contents are embedded at compile time with `include_str!` and
/// wrapped by `sprite_shader_from_fragment`. The literal path is also stored
/// as a `String` so the file can be watched for changes at runtime.
///
/// `ReloadableShaderSource` and `sprite_shader_from_fragment` must be in scope
/// at the call site.
#[macro_export]
macro_rules! reloadable_shader_source {
    ($path:literal) => {
        ReloadableShaderSource {
            static_source: sprite_shader_from_fragment(include_str!($path)),
            path: ::std::string::String::from($path),
        }
    };
}
19
20pub fn watch_shader_path(
21 path: &str,
22 shader_id: ShaderId,
23) -> notify::Result<()> {
24 let path = Path::new(path).canonicalize().unwrap().to_path_buf();
25
26 let mut hot_reload = HOT_RELOAD.lock();
27 hot_reload.watch_path(path.as_path())?;
28 hot_reload.shader_paths.insert(path, shader_id);
29
30 Ok(())
31}
32
33pub fn maybe_reload_shaders(shaders: &mut ShaderMap) {
35 HOT_RELOAD.lock().maybe_reload_shaders(shaders);
36}
37
/// State for shader hot reloading: a file watcher, the channel its events
/// arrive on, and the mapping from watched file paths to shader ids.
pub struct HotReload {
    /// Receives file-system events (or watcher errors) sent by `watcher`.
    rx: Receiver<Result<Event, notify::Error>>,
    /// The platform's recommended `notify` watcher. It is kept here so it
    /// stays alive, and keeps watching, for as long as this struct exists.
    watcher: notify::RecommendedWatcher,
    /// Maps watched shader file paths to the `ShaderId` that is updated when
    /// that file changes. `watch_shader_path` inserts canonicalized paths here.
    pub shader_paths: HashMap<PathBuf, ShaderId>,
}
43
44impl HotReload {
45 pub fn new() -> Self {
46 info!("SHADER HOT RELOADING ENABLED!");
47
48 let (tx, rx) = std::sync::mpsc::channel();
49
50 let watcher =
51 notify::RecommendedWatcher::new(tx, Default::default()).unwrap();
52
53 Self { rx, watcher, shader_paths: HashMap::new() }
54 }
55
56 pub fn watch_path(&mut self, path: &Path) -> notify::Result<()> {
57 self.watcher.watch(path, RecursiveMode::Recursive)?;
58
59 Ok(())
60 }
61
62 pub fn maybe_reload_shaders(&self, shaders: &mut ShaderMap) -> bool {
63 let mut reload = false;
64
65 if let Ok(maybe_event) = self.rx.try_recv() {
66 match maybe_event {
67 Ok(event) => {
68 let is_close_write = matches!(
69 event.kind,
70 EventKind::Access(AccessKind::Close(
71 notify::event::AccessMode::Write
72 ))
73 );
74
75 let is_temp = event
76 .paths
77 .iter()
78 .all(|p| p.to_string_lossy().ends_with('~'));
79
80 if is_close_write && !is_temp {
81 reload = true;
82
83 self.reload_path_bufs(shaders, &event.paths);
84 }
85 }
86
87 Err(err) => eprintln!("Error: {:?}", err),
88 }
89 }
90
91 reload
92 }
93
94 fn reload_path_bufs(&self, shaders: &mut ShaderMap, paths: &[PathBuf]) {
95 for path in paths.iter().filter(|x| !x.to_string_lossy().ends_with('~'))
96 {
97 if let Some(shader_id) = self.shader_paths.get(path) {
98 match std::fs::read_to_string(path) {
99 Ok(source) => {
100 let fragment_source =
101 &sprite_shader_from_fragment(&source);
102
103 checked_update_shader(
104 shaders,
105 *shader_id,
106 fragment_source,
107 );
108 }
109
110 Err(error) => {
111 error!(
112 "Error loading a shader at {}: {:?}",
113 path.to_string_lossy(),
114 error
115 )
116 }
117 }
118 } else {
119 error!(
120 "Trying to reload shader at {} but no ShaderId defined \
121 for that path. This likely means a wrong path was passed \
122 to `create_reloadable_shader`. Existing paths: {:?}",
123 path.to_string_lossy(),
124 self.shader_paths
125 );
126 }
127 }
128 }
129}
130
131pub fn check_shader_with_naga(source: &str) -> Result<()> {
132 let module = naga::front::wgsl::parse_str(source)?;
133
134 let mut validator = naga::valid::Validator::new(
135 naga::valid::ValidationFlags::all(),
136 naga::valid::Capabilities::all(),
137 );
138
139 validator.validate(&module)?;
140
141 Ok(())
142}
143
144pub fn checked_update_shader(
147 shaders: &mut ShaderMap,
148 id: ShaderId,
149 fragment_source: &str,
150) {
151 let shader_error_id = format!("{}-shader", id);
152
153 if let Some(shader) = shaders.shaders.get_mut(&id) {
154 let final_source = build_shader_source(
155 fragment_source,
156 &shader.bindings,
157 &shader.uniform_defs,
158 );
159
160 match check_shader_with_naga(&final_source) {
161 Ok(()) => {
162 clear_error(shader_error_id);
163 shader.source = final_source;
164 }
165 Err(err) => {
166 report_error(shader_error_id, format!("SHADER ERROR: {}", err));
167 error!("SHADER COMPILE ERROR:\n{:?}", err);
168 }
169 }
170 }
171}