1use libloading::{Library, Symbol};
2use notify::{RecursiveMode, Watcher};
3use notify_debouncer_full::new_debouncer;
4use std::fs;
5use std::path::{Path, PathBuf};
6use std::sync::{
7 Arc, Mutex,
8 atomic::{AtomicBool, AtomicU32, Ordering},
9 mpsc,
10};
11use std::thread;
12use std::time::Duration;
13
14use crate::error::HotReloaderError;
15
/// Loads a dynamic library from a copy of its build output and reloads it
/// after a background watcher thread reports that the build output changed.
///
/// NOTE: field declaration order is also drop order (after `Drop::drop` has
/// run), and `lib` holds a loaded library — keep that in mind before
/// reordering fields.
pub struct LibReloader {
    /// Number of reloads performed so far; feeds the `{load_counter}` part of
    /// the loaded copy's file name.
    load_counter: usize,
    /// Directory containing the library, as resolved when the reloader was created.
    lib_dir: PathBuf,
    /// Library name without platform prefix or extension.
    lib_name: String,
    /// Set by the watcher thread when the library file changed; cleared by `update`.
    changed: Arc<AtomicBool>,
    /// The currently loaded library, or `None` if nothing is loaded.
    lib: Option<Library>,
    /// The library file written by the build and observed by the watcher.
    watched_lib_file: PathBuf,
    /// The copy of `watched_lib_file` that is actually loaded.
    loaded_lib_file: PathBuf,
    /// CRC32 of the most recently copied library; the watcher compares the
    /// watched file against it to ignore events that did not change the content.
    lib_file_hash: Arc<AtomicU32>,
    /// Senders that receive `()` whenever the watcher signals a change.
    file_change_subscribers: Arc<Mutex<Vec<mpsc::Sender<()>>>>,
    /// Signs the loaded copy before it is loaded (macOS only).
    #[cfg(target_os = "macos")]
    codesigner: crate::codesign::CodeSigner,
    /// Optional file-name template for the loaded copy; the placeholders are
    /// substituted in `watched_and_loaded_library_paths`.
    loaded_lib_name_template: Option<String>,
}
42
43impl LibReloader {
44 pub fn new(
52 lib_dir: impl AsRef<Path>,
53 lib_name: impl AsRef<str>,
54 file_watch_debounce: Option<Duration>,
55 loaded_lib_name_template: Option<String>,
56 ) -> Result<Self, HotReloaderError> {
57 let lib_dir = find_file_or_dir_in_parent_directories(lib_dir.as_ref())?;
60 log::debug!("found lib dir at {lib_dir:?}");
61
62 let load_counter = 0;
63
64 #[cfg(target_os = "macos")]
65 let codesigner = crate::codesign::CodeSigner::new();
66
67 let (watched_lib_file, loaded_lib_file) = watched_and_loaded_library_paths(
68 &lib_dir,
69 &lib_name,
70 load_counter,
71 &loaded_lib_name_template,
72 );
73
74 let (lib_file_hash, lib) = if watched_lib_file.exists() {
75 log::debug!("copying {watched_lib_file:?} -> {loaded_lib_file:?}");
78 fs::copy(&watched_lib_file, &loaded_lib_file)?;
79 let hash = hash_file(&loaded_lib_file);
80 #[cfg(target_os = "macos")]
81 codesigner.codesign(&loaded_lib_file);
82 (hash, Some(load_library(&loaded_lib_file)?))
83 } else {
84 log::debug!("library {watched_lib_file:?} does not yet exist");
85 (0, None)
86 };
87
88 let lib_file_hash = Arc::new(AtomicU32::new(lib_file_hash));
89 let changed = Arc::new(AtomicBool::new(false));
90 let file_change_subscribers = Arc::new(Mutex::new(Vec::new()));
91 Self::watch(
92 watched_lib_file.clone(),
93 lib_file_hash.clone(),
94 changed.clone(),
95 file_change_subscribers.clone(),
96 file_watch_debounce.unwrap_or_else(|| Duration::from_millis(500)),
97 )?;
98
99 let lib_loader = Self {
100 load_counter,
101 lib_dir,
102 lib_name: lib_name.as_ref().to_string(),
103 watched_lib_file,
104 loaded_lib_file,
105 lib,
106 lib_file_hash,
107 changed,
108 file_change_subscribers,
109 #[cfg(target_os = "macos")]
110 codesigner,
111 loaded_lib_name_template,
112 };
113
114 Ok(lib_loader)
115 }
116
117 #[doc(hidden)]
119 pub fn subscribe_to_file_changes(&mut self) -> mpsc::Receiver<()> {
120 log::trace!("subscribe to file change");
121 let (tx, rx) = mpsc::channel();
122 let mut subscribers = self.file_change_subscribers.lock().unwrap();
123 subscribers.push(tx);
124 rx
125 }
126
127 pub fn update(&mut self) -> Result<bool, HotReloaderError> {
130 if !self.changed.load(Ordering::Acquire) {
131 return Ok(false);
132 }
133 self.changed.store(false, Ordering::Release);
134
135 self.reload()?;
136
137 Ok(true)
138 }
139
140 fn reload(&mut self) -> Result<(), HotReloaderError> {
142 let Self {
143 load_counter,
144 lib_dir,
145 lib_name,
146 watched_lib_file,
147 loaded_lib_file,
148 lib,
149 loaded_lib_name_template,
150 ..
151 } = self;
152
153 log::info!("reloading lib {watched_lib_file:?}");
154
155 if let Some(lib) = lib.take() {
157 lib.close()?;
158 if loaded_lib_file.exists() {
159 let _ = fs::remove_file(&loaded_lib_file);
160 }
161 }
162
163 if watched_lib_file.exists() {
164 *load_counter += 1;
165 let (_, loaded_lib_file) = watched_and_loaded_library_paths(
166 lib_dir,
167 lib_name,
168 *load_counter,
169 loaded_lib_name_template,
170 );
171 log::trace!("copy {watched_lib_file:?} -> {loaded_lib_file:?}");
172 fs::copy(watched_lib_file, &loaded_lib_file)?;
173 self.lib_file_hash
174 .store(hash_file(&loaded_lib_file), Ordering::Release);
175 #[cfg(target_os = "macos")]
176 self.codesigner.codesign(&loaded_lib_file);
177 self.lib = Some(load_library(&loaded_lib_file)?);
178 self.loaded_lib_file = loaded_lib_file;
179 } else {
180 log::warn!("trying to reload library but it does not exist");
181 }
182
183 Ok(())
184 }
185
186 fn watch(
188 lib_file: impl AsRef<Path>,
189 lib_file_hash: Arc<AtomicU32>,
190 changed: Arc<AtomicBool>,
191 file_change_subscribers: Arc<Mutex<Vec<mpsc::Sender<()>>>>,
192 debounce: Duration,
193 ) -> Result<(), HotReloaderError> {
194 let lib_file = lib_file.as_ref().to_path_buf();
195 log::info!("start watching changes of file {}", lib_file.display());
196
197 thread::spawn(move || {
201 let (tx, rx) = mpsc::channel();
202
203 let mut debouncer =
204 new_debouncer(debounce, None, tx).expect("creating notify debouncer");
205
206 debouncer
207 .watcher()
208 .watch(&lib_file, RecursiveMode::NonRecursive)
209 .expect("watch lib file");
210
211 let signal_change = || {
221 if hash_file(&lib_file) == lib_file_hash.load(Ordering::Acquire)
222 || changed.load(Ordering::Acquire)
223 {
224 return false;
226 }
227
228 log::debug!("{lib_file:?} changed",);
229
230 changed.store(true, Ordering::Release);
231
232 let subscribers = file_change_subscribers.lock().unwrap();
234 log::trace!(
235 "sending ChangedEvent::LibFileChanged to {} subscribers",
236 subscribers.len()
237 );
238 for tx in &*subscribers {
239 let _ = tx.send(());
240 }
241
242 true
243 };
244
245 loop {
246 match rx.recv() {
247 Err(_) => {
248 log::info!("file watcher channel closed");
249 break;
250 }
251 Ok(events) => {
252 let events = match events {
253 Err(errors) => {
254 log::error!("{} file watcher error!", errors.len());
255 for err in errors {
256 log::error!(" {err}");
257 }
258 continue;
259 }
260 Ok(events) => events,
261 };
262
263 log::trace!("file change events: {events:?}");
264 let was_removed =
265 events
266 .iter()
267 .fold(false, |was_removed, event| match event.kind {
268 notify::EventKind::Create(_) | notify::EventKind::Modify(_) => {
269 false
270 }
271 notify::EventKind::Remove(_) => true,
272 _ => was_removed,
273 });
274 if was_removed || !lib_file.exists() {
276 log::debug!(
277 "{} was removed, trying to watch it again...",
278 lib_file.display()
279 );
280 }
281 loop {
282 if debouncer
283 .watcher()
284 .watch(&lib_file, RecursiveMode::NonRecursive)
285 .is_ok()
286 {
287 log::info!("watching {lib_file:?} again after removal");
288 signal_change();
289 break;
290 }
291 thread::sleep(Duration::from_millis(500));
292 }
293 }
294 }
295 }
296 });
297
298 Ok(())
299 }
300
301 pub unsafe fn get_symbol<T>(&self, name: &[u8]) -> Result<Symbol<'_, T>, HotReloaderError> {
312 unsafe {
313 match &self.lib {
314 None => Err(HotReloaderError::LibraryNotLoaded),
315 Some(lib) => Ok(lib.get(name)?),
316 }
317 }
318 }
319
320 #[doc(hidden)]
323 pub fn log_info(what: impl std::fmt::Display) {
324 log::info!("{what}");
325 }
326}
327
328impl Drop for LibReloader {
330 fn drop(&mut self) {
331 if self.loaded_lib_file.exists() {
332 log::trace!("removing {:?}", self.loaded_lib_file);
333 let _ = fs::remove_file(&self.loaded_lib_file);
334 }
335 }
336}
337
/// Computes the library file produced by the build (`watched`) and the copy
/// that is actually loaded (`loaded`) for reload number `load_counter`.
///
/// Both file names get the platform's dynamic-library prefix and extension
/// from `std::env::consts::DLL_PREFIX` / `DLL_EXTENSION`, e.g. `libfoo.so`,
/// `libfoo.dylib` or `foo.dll`.
///
/// The loaded name defaults to `<prefix><lib_name>-hot-<load_counter>`. A
/// `loaded_lib_name_template` may instead use these placeholders:
/// `{lib_name}` (already prefixed), `{load_counter}`, `{pid}`, and, with the
/// `uuid` feature, `{uuid}`.
///
/// The extension is appended rather than substituted, so names that contain
/// dots keep their full stem. For example, the template
/// `{lib_name}.{load_counter}` no longer collapses to the watched file name.
/// A name that already ends in the platform extension is left unchanged.
fn watched_and_loaded_library_paths(
    lib_dir: impl AsRef<Path>,
    lib_name: impl AsRef<str>,
    load_counter: usize,
    loaded_lib_name_template: &Option<impl AsRef<str>>,
) -> (PathBuf, PathBuf) {
    let lib_dir = lib_dir.as_ref();

    let prefix = std::env::consts::DLL_PREFIX;
    let dll_suffix = format!(".{}", std::env::consts::DLL_EXTENSION);
    // `Path::with_extension` would replace everything after the last dot
    // (`libfoo.v1` -> `libfoo.so`), so append the extension instead.
    let with_dll_extension = |name: String| {
        if name.ends_with(&dll_suffix) {
            name
        } else {
            name + &dll_suffix
        }
    };

    let lib_name = format!("{prefix}{}", lib_name.as_ref());

    let watched_lib_file = lib_dir.join(with_dll_extension(lib_name.clone()));

    let loaded_lib_filename = match loaded_lib_name_template {
        Some(template) => {
            let name = template
                .as_ref()
                .replace("{lib_name}", &lib_name)
                .replace("{load_counter}", &load_counter.to_string())
                .replace("{pid}", &std::process::id().to_string());
            #[cfg(feature = "uuid")]
            let name = name.replace("{uuid}", &uuid::Uuid::new_v4().to_string());
            name
        }
        None => format!("{lib_name}-hot-{load_counter}"),
    };
    let loaded_lib_file = lib_dir.join(with_dll_extension(loaded_lib_filename));

    (watched_lib_file, loaded_lib_file)
}
378
379fn find_file_or_dir_in_parent_directories(
383 file: impl AsRef<Path>,
384) -> Result<PathBuf, HotReloaderError> {
385 let mut file = file.as_ref().to_path_buf();
386 if !file.exists()
387 && file.is_relative()
388 && let Ok(cwd) = std::env::current_dir()
389 {
390 let mut parent_dir = Some(cwd.as_path());
391 while let Some(dir) = parent_dir {
392 if dir.join(&file).exists() {
393 file = dir.join(&file);
394 break;
395 }
396 parent_dir = dir.parent();
397 }
398 }
399
400 if file.exists() {
401 Ok(file)
402 } else {
403 Err(std::io::Error::new(
404 std::io::ErrorKind::NotFound,
405 format!("file {file:?} does not exist"),
406 )
407 .into())
408 }
409}
410
411fn load_library(lib_file: impl AsRef<Path>) -> Result<Library, HotReloaderError> {
412 Ok(unsafe { Library::new(lib_file.as_ref()) }?)
413}
414
415fn hash_file(f: impl AsRef<Path>) -> u32 {
416 fs::read(f.as_ref())
417 .map(|content| crc32fast::hash(&content))
418 .unwrap_or_default()
419}