hot_lib_reloader/
lib_reloader.rs

1use libloading::{Library, Symbol};
2use notify::{RecursiveMode, Watcher};
3use notify_debouncer_full::new_debouncer;
4use std::fs;
5use std::path::{Path, PathBuf};
6use std::sync::{
7    Arc, Mutex,
8    atomic::{AtomicBool, AtomicU32, Ordering},
9    mpsc,
10};
11use std::thread;
12use std::time::Duration;
13
14use crate::error::HotReloaderError;
15
16/// Manages watches a library (dylib) file, loads it using
17/// [`libloading::Library`] and [provides access to its
18/// symbols](LibReloader::get_symbol). When the library changes, [`LibReloader`]
19/// is able to unload the old version and reload the new version through
20/// [`LibReloader::update`].
21///
22/// Note that the [`LibReloader`] itself will not actively update, i.e. does not
23/// manage an update thread calling the update function. This is normally
24/// managed by the [`hot_lib_reloader_macro::hot_module`] macro that also
25/// manages the [about-to-load and load](crate::LibReloadNotifier) notifications.
26///
27/// It can load symbols from the library with [LibReloader::get_symbol].
28pub struct LibReloader {
29    load_counter: usize,
30    lib_dir: PathBuf,
31    lib_name: String,
32    changed: Arc<AtomicBool>,
33    lib: Option<Library>,
34    watched_lib_file: PathBuf,
35    loaded_lib_file: PathBuf,
36    lib_file_hash: Arc<AtomicU32>,
37    file_change_subscribers: Arc<Mutex<Vec<mpsc::Sender<()>>>>,
38    #[cfg(target_os = "macos")]
39    codesigner: crate::codesign::CodeSigner,
40    loaded_lib_name_template: Option<String>,
41}
42
43impl LibReloader {
44    /// Creates a LibReloader.
45    ///  `lib_dir` is expected to be the location where the library to use can
46    /// be found. Probably `target/debug` normally.
47    /// `lib_name` is the name of the library, not(!) the file name. It should
48    /// normally be just the crate name of the cargo project you want to hot-reload.
49    /// LibReloader will take care to figure out the actual file name with
50    /// platform-specific prefix and extension.
51    pub fn new(
52        lib_dir: impl AsRef<Path>,
53        lib_name: impl AsRef<str>,
54        file_watch_debounce: Option<Duration>,
55        loaded_lib_name_template: Option<String>,
56    ) -> Result<Self, HotReloaderError> {
57        // find the target dir in which the build is happening and where we should find
58        // the library
59        let lib_dir = find_file_or_dir_in_parent_directories(lib_dir.as_ref())?;
60        log::debug!("found lib dir at {lib_dir:?}");
61
62        let load_counter = 0;
63
64        #[cfg(target_os = "macos")]
65        let codesigner = crate::codesign::CodeSigner::new();
66
67        let (watched_lib_file, loaded_lib_file) = watched_and_loaded_library_paths(
68            &lib_dir,
69            &lib_name,
70            load_counter,
71            &loaded_lib_name_template,
72        );
73
74        let (lib_file_hash, lib) = if watched_lib_file.exists() {
75            // We don't load the actual lib because this can get problems e.g. on Windows
76            // where a file lock would be held, preventing the lib from changing later.
77            log::debug!("copying {watched_lib_file:?} -> {loaded_lib_file:?}");
78            fs::copy(&watched_lib_file, &loaded_lib_file)?;
79            let hash = hash_file(&loaded_lib_file);
80            #[cfg(target_os = "macos")]
81            codesigner.codesign(&loaded_lib_file);
82            (hash, Some(load_library(&loaded_lib_file)?))
83        } else {
84            log::debug!("library {watched_lib_file:?} does not yet exist");
85            (0, None)
86        };
87
88        let lib_file_hash = Arc::new(AtomicU32::new(lib_file_hash));
89        let changed = Arc::new(AtomicBool::new(false));
90        let file_change_subscribers = Arc::new(Mutex::new(Vec::new()));
91        Self::watch(
92            watched_lib_file.clone(),
93            lib_file_hash.clone(),
94            changed.clone(),
95            file_change_subscribers.clone(),
96            file_watch_debounce.unwrap_or_else(|| Duration::from_millis(500)),
97        )?;
98
99        let lib_loader = Self {
100            load_counter,
101            lib_dir,
102            lib_name: lib_name.as_ref().to_string(),
103            watched_lib_file,
104            loaded_lib_file,
105            lib,
106            lib_file_hash,
107            changed,
108            file_change_subscribers,
109            #[cfg(target_os = "macos")]
110            codesigner,
111            loaded_lib_name_template,
112        };
113
114        Ok(lib_loader)
115    }
116
117    // needs to be public as it is used inside the hot_module macro.
118    #[doc(hidden)]
119    pub fn subscribe_to_file_changes(&mut self) -> mpsc::Receiver<()> {
120        log::trace!("subscribe to file change");
121        let (tx, rx) = mpsc::channel();
122        let mut subscribers = self.file_change_subscribers.lock().unwrap();
123        subscribers.push(tx);
124        rx
125    }
126
127    /// Checks if the watched library has changed. If it has, reload it and return
128    /// true. Otherwise return false.
129    pub fn update(&mut self) -> Result<bool, HotReloaderError> {
130        if !self.changed.load(Ordering::Acquire) {
131            return Ok(false);
132        }
133        self.changed.store(false, Ordering::Release);
134
135        self.reload()?;
136
137        Ok(true)
138    }
139
140    /// Reload library `self.lib_file`.
141    fn reload(&mut self) -> Result<(), HotReloaderError> {
142        let Self {
143            load_counter,
144            lib_dir,
145            lib_name,
146            watched_lib_file,
147            loaded_lib_file,
148            lib,
149            loaded_lib_name_template,
150            ..
151        } = self;
152
153        log::info!("reloading lib {watched_lib_file:?}");
154
155        // Close the loaded lib, copy the new lib to a file we can load, then load it.
156        if let Some(lib) = lib.take() {
157            lib.close()?;
158            if loaded_lib_file.exists() {
159                let _ = fs::remove_file(&loaded_lib_file);
160            }
161        }
162
163        if watched_lib_file.exists() {
164            *load_counter += 1;
165            let (_, loaded_lib_file) = watched_and_loaded_library_paths(
166                lib_dir,
167                lib_name,
168                *load_counter,
169                loaded_lib_name_template,
170            );
171            log::trace!("copy {watched_lib_file:?} -> {loaded_lib_file:?}");
172            fs::copy(watched_lib_file, &loaded_lib_file)?;
173            self.lib_file_hash
174                .store(hash_file(&loaded_lib_file), Ordering::Release);
175            #[cfg(target_os = "macos")]
176            self.codesigner.codesign(&loaded_lib_file);
177            self.lib = Some(load_library(&loaded_lib_file)?);
178            self.loaded_lib_file = loaded_lib_file;
179        } else {
180            log::warn!("trying to reload library but it does not exist");
181        }
182
183        Ok(())
184    }
185
186    /// Watch for changes of `lib_file`.
187    fn watch(
188        lib_file: impl AsRef<Path>,
189        lib_file_hash: Arc<AtomicU32>,
190        changed: Arc<AtomicBool>,
191        file_change_subscribers: Arc<Mutex<Vec<mpsc::Sender<()>>>>,
192        debounce: Duration,
193    ) -> Result<(), HotReloaderError> {
194        let lib_file = lib_file.as_ref().to_path_buf();
195        log::info!("start watching changes of file {}", lib_file.display());
196
197        // File watcher thread. We watch `self.lib_file`, when it changes and we haven't
198        // a pending change still waiting to be loaded, set `self.changed` to true. This
199        // then gets picked up by `self.update`.
200        thread::spawn(move || {
201            let (tx, rx) = mpsc::channel();
202
203            let mut debouncer =
204                new_debouncer(debounce, None, tx).expect("creating notify debouncer");
205
206            debouncer
207                .watcher()
208                .watch(&lib_file, RecursiveMode::NonRecursive)
209                .expect("watch lib file");
210
211            // debouncer
212            //     .cache()
213            //     .add_root(dir.path(), RecursiveMode::Recursive);
214
215            // let mut watcher = RecommendedWatcher::new(tx, Config::default()).unwrap();
216            // watcher
217            //     .watch(&lib_file, RecursiveMode::NonRecursive)
218            //     .expect("watch lib file");
219
220            let signal_change = || {
221                if hash_file(&lib_file) == lib_file_hash.load(Ordering::Acquire)
222                    || changed.load(Ordering::Acquire)
223                {
224                    // file not changed
225                    return false;
226                }
227
228                log::debug!("{lib_file:?} changed",);
229
230                changed.store(true, Ordering::Release);
231
232                // inform subscribers
233                let subscribers = file_change_subscribers.lock().unwrap();
234                log::trace!(
235                    "sending ChangedEvent::LibFileChanged to {} subscribers",
236                    subscribers.len()
237                );
238                for tx in &*subscribers {
239                    let _ = tx.send(());
240                }
241
242                true
243            };
244
245            loop {
246                match rx.recv() {
247                    Err(_) => {
248                        log::info!("file watcher channel closed");
249                        break;
250                    }
251                    Ok(events) => {
252                        let events = match events {
253                            Err(errors) => {
254                                log::error!("{} file watcher error!", errors.len());
255                                for err in errors {
256                                    log::error!("  {err}");
257                                }
258                                continue;
259                            }
260                            Ok(events) => events,
261                        };
262
263                        log::trace!("file change events: {events:?}");
264                        let was_removed =
265                            events
266                                .iter()
267                                .fold(false, |was_removed, event| match event.kind {
268                                    notify::EventKind::Create(_) | notify::EventKind::Modify(_) => {
269                                        false
270                                    }
271                                    notify::EventKind::Remove(_) => true,
272                                    _ => was_removed,
273                                });
274                        // just one hard link removed?
275                        if was_removed || !lib_file.exists() {
276                            log::debug!(
277                                "{} was removed, trying to watch it again...",
278                                lib_file.display()
279                            );
280                        }
281                        loop {
282                            if debouncer
283                                .watcher()
284                                .watch(&lib_file, RecursiveMode::NonRecursive)
285                                .is_ok()
286                            {
287                                log::info!("watching {lib_file:?} again after removal");
288                                signal_change();
289                                break;
290                            }
291                            thread::sleep(Duration::from_millis(500));
292                        }
293                    }
294                }
295            }
296        });
297
298        Ok(())
299    }
300
301    /// Get a pointer to a function or static variable by symbol name. Just a
302    /// wrapper around [libloading::Library::get].
303    ///
304    /// The `symbol` may not contain any null bytes, with the exception of the
305    /// last byte. Providing a null-terminated `symbol` may help to avoid an
306    /// allocation. The symbol is interpreted as is, no mangling.
307    ///
308    /// # Safety
309    ///
310    /// Users of this API must specify the correct type of the function or variable loaded.
311    pub unsafe fn get_symbol<T>(&self, name: &[u8]) -> Result<Symbol<'_, T>, HotReloaderError> {
312        unsafe {
313            match &self.lib {
314                None => Err(HotReloaderError::LibraryNotLoaded),
315                Some(lib) => Ok(lib.get(name)?),
316            }
317        }
318    }
319
320    /// Helper to log from the macro without requiring the user to have the log
321    /// crate around
322    #[doc(hidden)]
323    pub fn log_info(what: impl std::fmt::Display) {
324        log::info!("{what}");
325    }
326}
327
328/// Deletes the currently loaded lib file if it exists
329impl Drop for LibReloader {
330    fn drop(&mut self) {
331        if self.loaded_lib_file.exists() {
332            log::trace!("removing {:?}", self.loaded_lib_file);
333            let _ = fs::remove_file(&self.loaded_lib_file);
334        }
335    }
336}
337
/// Computes the path of the watched library and of the copy that gets loaded.
///
/// The watched file is `lib_dir/<prefix><lib_name>.<ext>` using the platform's
/// dynamic-library prefix and extension (`lib`/`so` on Linux, `lib`/`dylib` on
/// macOS, no prefix and `dll` on Windows). The loaded copy lives in the same
/// directory. Its name is built from `loaded_lib_name_template` when given (the
/// placeholders `{lib_name}` (the prefixed name), `{load_counter}`, `{pid}` and,
/// with the `uuid` feature, `{uuid}` are substituted), and is
/// `<prefix><lib_name>-hot-<load_counter>` otherwise. The extension is applied
/// with [`Path::with_extension`], so it replaces an existing extension (the part
/// after the last `.`) of either name.
///
/// Returns `(watched_lib_file, loaded_lib_file)`.
fn watched_and_loaded_library_paths(
    lib_dir: impl AsRef<Path>,
    lib_name: impl AsRef<str>,
    load_counter: usize,
    loaded_lib_name_template: &Option<impl AsRef<str>>,
) -> (PathBuf, PathBuf) {
    let lib_dir = lib_dir.as_ref();

    // OS dependent file name parts. These match the values previously hard-coded
    // for macOS, Linux and Windows, and also exist on other targets (e.g. the BSDs)
    // where the per-OS `cfg` version failed to compile.
    let prefix = std::env::consts::DLL_PREFIX;
    let ext = std::env::consts::DLL_EXTENSION;
    let lib_name = format!("{prefix}{}", lib_name.as_ref());

    let watched_lib_file = lib_dir.join(&lib_name).with_extension(ext);

    let loaded_lib_filename = match loaded_lib_name_template {
        Some(loaded_lib_name_template) => {
            let result = loaded_lib_name_template
                .as_ref()
                .replace("{lib_name}", &lib_name)
                .replace("{load_counter}", &load_counter.to_string())
                .replace("{pid}", &std::process::id().to_string());
            // A fresh random UUID is generated on every call.
            #[cfg(feature = "uuid")]
            {
                result.replace("{uuid}", &uuid::Uuid::new_v4().to_string())
            }
            #[cfg(not(feature = "uuid"))]
            {
                result
            }
        }
        None => format!("{lib_name}-hot-{load_counter}"),
    };
    let loaded_lib_file = lib_dir.join(loaded_lib_filename).with_extension(ext);
    (watched_lib_file, loaded_lib_file)
}
378
379/// Try to find that might be a relative path such as `target/debug/` by walking
380/// up the directories, starting from cwd. This helps finding the lib when the
381/// app was started from a directory that is not the project/workspace root.
382fn find_file_or_dir_in_parent_directories(
383    file: impl AsRef<Path>,
384) -> Result<PathBuf, HotReloaderError> {
385    let mut file = file.as_ref().to_path_buf();
386    if !file.exists()
387        && file.is_relative()
388        && let Ok(cwd) = std::env::current_dir()
389    {
390        let mut parent_dir = Some(cwd.as_path());
391        while let Some(dir) = parent_dir {
392            if dir.join(&file).exists() {
393                file = dir.join(&file);
394                break;
395            }
396            parent_dir = dir.parent();
397        }
398    }
399
400    if file.exists() {
401        Ok(file)
402    } else {
403        Err(std::io::Error::new(
404            std::io::ErrorKind::NotFound,
405            format!("file {file:?} does not exist"),
406        )
407        .into())
408    }
409}
410
411fn load_library(lib_file: impl AsRef<Path>) -> Result<Library, HotReloaderError> {
412    Ok(unsafe { Library::new(lib_file.as_ref()) }?)
413}
414
415fn hash_file(f: impl AsRef<Path>) -> u32 {
416    fs::read(f.as_ref())
417        .map(|content| crc32fast::hash(&content))
418        .unwrap_or_default()
419}