// heliosdb_proxy/plugins/hot_reload.rs

//! Hot Reload Support
//!
//! Polling-based file watching and automatic reloading of `.wasm` plugins.
4
5use std::collections::HashMap;
6use std::path::{Path, PathBuf};
7use std::sync::Arc;
8use std::time::{Duration, Instant, SystemTime};
9
10use parking_lot::RwLock;
11
12use super::runtime::PluginError;
13
14/// Hot reloader for plugins
15pub struct HotReloader {
16    /// Watch directory
17    watch_dir: PathBuf,
18
19    /// File modification times
20    file_times: RwLock<HashMap<PathBuf, SystemTime>>,
21
22    /// Plugin name to file mapping
23    plugin_files: RwLock<HashMap<String, PathBuf>>,
24
25    /// Debounce duration (ignore rapid changes)
26    debounce: Duration,
27
28    /// Last check time
29    last_check: RwLock<Instant>,
30
31    /// Minimum interval between checks
32    check_interval: Duration,
33
34    /// Pending events (debounced)
35    pending_events: RwLock<HashMap<PathBuf, (ReloadEventType, Instant)>>,
36}
37
38impl HotReloader {
39    /// Create a new hot reloader
40    pub fn new(watch_dir: &Path) -> Result<Self, PluginError> {
41        if !watch_dir.exists() {
42            return Err(PluginError::LoadError(format!(
43                "Watch directory does not exist: {}",
44                watch_dir.display()
45            )));
46        }
47
48        Ok(Self {
49            watch_dir: watch_dir.to_path_buf(),
50            file_times: RwLock::new(HashMap::new()),
51            plugin_files: RwLock::new(HashMap::new()),
52            debounce: Duration::from_millis(500),
53            last_check: RwLock::new(Instant::now()),
54            check_interval: Duration::from_millis(100),
55            pending_events: RwLock::new(HashMap::new()),
56        })
57    }
58
59    /// Register a plugin file
60    pub fn register(&self, plugin_name: &str, path: &Path) {
61        let mut plugin_files = self.plugin_files.write();
62        plugin_files.insert(plugin_name.to_string(), path.to_path_buf());
63
64        // Record initial modification time
65        if let Ok(metadata) = std::fs::metadata(path) {
66            if let Ok(modified) = metadata.modified() {
67                let mut file_times = self.file_times.write();
68                file_times.insert(path.to_path_buf(), modified);
69            }
70        }
71    }
72
73    /// Unregister a plugin file
74    pub fn unregister(&self, plugin_name: &str) {
75        let mut plugin_files = self.plugin_files.write();
76        if let Some(path) = plugin_files.remove(plugin_name) {
77            let mut file_times = self.file_times.write();
78            file_times.remove(&path);
79        }
80    }
81
82    /// Check for file changes
83    pub fn check(&self) -> Result<Vec<ReloadEvent>, PluginError> {
84        let now = Instant::now();
85
86        // Rate limit checks
87        {
88            let last = *self.last_check.read();
89            if now.duration_since(last) < self.check_interval {
90                return Ok(Vec::new());
91            }
92            *self.last_check.write() = now;
93        }
94
95        let mut events = Vec::new();
96
97        // Scan watch directory for new/removed files
98        events.extend(self.scan_directory()?);
99
100        // Check registered files for modifications
101        events.extend(self.check_modifications()?);
102
103        // Process pending events (apply debouncing)
104        events.extend(self.process_pending_events(now)?);
105
106        Ok(events)
107    }
108
109    /// Scan directory for new/removed files
110    fn scan_directory(&self) -> Result<Vec<ReloadEvent>, PluginError> {
111        let mut events = Vec::new();
112
113        if !self.watch_dir.exists() {
114            return Ok(events);
115        }
116
117        let entries = std::fs::read_dir(&self.watch_dir)
118            .map_err(|e| PluginError::RuntimeError(e.to_string()))?;
119
120        let mut current_files: HashMap<PathBuf, SystemTime> = HashMap::new();
121
122        for entry in entries.flatten() {
123            let path = entry.path();
124
125            // Only watch .wasm files
126            if path.extension().map(|e| e != "wasm").unwrap_or(true) {
127                continue;
128            }
129
130            if let Ok(metadata) = std::fs::metadata(&path) {
131                if let Ok(modified) = metadata.modified() {
132                    current_files.insert(path, modified);
133                }
134            }
135        }
136
137        // Check for new files
138        let file_times = self.file_times.read();
139        for path in current_files.keys() {
140            if !file_times.contains_key(path) {
141                // New file detected - add to pending
142                self.add_pending_event(path.clone(), ReloadEventType::Added);
143            }
144        }
145
146        // Check for removed files
147        for path in file_times.keys() {
148            if path.starts_with(&self.watch_dir) && !current_files.contains_key(path) {
149                // File removed
150                if let Some(name) = self.get_plugin_name(path) {
151                    events.push(ReloadEvent::Removed(name));
152                }
153            }
154        }
155
156        Ok(events)
157    }
158
159    /// Check registered files for modifications
160    fn check_modifications(&self) -> Result<Vec<ReloadEvent>, PluginError> {
161        let plugin_files = self.plugin_files.read();
162        let file_times = self.file_times.read();
163
164        for (_plugin_name, path) in plugin_files.iter() {
165            if let Ok(metadata) = std::fs::metadata(path) {
166                if let Ok(modified) = metadata.modified() {
167                    if let Some(old_time) = file_times.get(path) {
168                        if modified > *old_time {
169                            // File modified - add to pending
170                            self.add_pending_event(path.clone(), ReloadEventType::Modified);
171                        }
172                    }
173                }
174            }
175        }
176
177        Ok(Vec::new())
178    }
179
180    /// Add a pending event (for debouncing)
181    fn add_pending_event(&self, path: PathBuf, event_type: ReloadEventType) {
182        let mut pending = self.pending_events.write();
183        pending.insert(path, (event_type, Instant::now()));
184    }
185
186    /// Process pending events after debounce period
187    fn process_pending_events(&self, now: Instant) -> Result<Vec<ReloadEvent>, PluginError> {
188        let mut events = Vec::new();
189        let mut to_remove = Vec::new();
190
191        {
192            let pending = self.pending_events.read();
193            for (path, (event_type, timestamp)) in pending.iter() {
194                if now.duration_since(*timestamp) >= self.debounce {
195                    match event_type {
196                        ReloadEventType::Modified => {
197                            if let Some(name) = self.get_plugin_name(path) {
198                                events.push(ReloadEvent::Modified(name));
199                            }
200                        }
201                        ReloadEventType::Added => {
202                            events.push(ReloadEvent::Added(path.clone()));
203                        }
204                        ReloadEventType::Removed => {
205                            if let Some(name) = self.get_plugin_name(path) {
206                                events.push(ReloadEvent::Removed(name));
207                            }
208                        }
209                    }
210                    to_remove.push(path.clone());
211                }
212            }
213        }
214
215        // Remove processed events and update file times
216        {
217            let mut pending = self.pending_events.write();
218            let mut file_times = self.file_times.write();
219
220            for path in to_remove {
221                pending.remove(&path);
222
223                // Update file time
224                if let Ok(metadata) = std::fs::metadata(&path) {
225                    if let Ok(modified) = metadata.modified() {
226                        file_times.insert(path, modified);
227                    }
228                }
229            }
230        }
231
232        Ok(events)
233    }
234
235    /// Get plugin name for a path
236    fn get_plugin_name(&self, path: &Path) -> Option<String> {
237        let plugin_files = self.plugin_files.read();
238        for (name, p) in plugin_files.iter() {
239            if p == path {
240                return Some(name.clone());
241            }
242        }
243
244        // Fall back to filename
245        path.file_stem()
246            .and_then(|s| s.to_str())
247            .map(|s| s.to_string())
248    }
249
250    /// Set debounce duration
251    pub fn set_debounce(&mut self, duration: Duration) {
252        self.debounce = duration;
253    }
254
255    /// Set check interval
256    pub fn set_check_interval(&mut self, interval: Duration) {
257        self.check_interval = interval;
258    }
259
260    /// Get watch directory
261    pub fn watch_dir(&self) -> &Path {
262        &self.watch_dir
263    }
264
265    /// Get registered plugin count
266    pub fn plugin_count(&self) -> usize {
267        self.plugin_files.read().len()
268    }
269}
270
/// Kind of change queued for debouncing (internal).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ReloadEventType {
    /// A registered plugin file's modification time changed.
    Modified,
    /// A `.wasm` file appeared in the watch directory.
    Added,
    /// Not queued at present: the directory scan reports removals immediately instead of
    /// debouncing them, so nothing constructs this variant yet.
    #[allow(dead_code)]
    Removed,
}
279
/// Change reported by `HotReloader::check`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReloadEvent {
    /// A registered plugin file was modified; carries the registered plugin name.
    Modified(String),

    /// A plugin file inside the watch directory was removed; carries the registered plugin
    /// name, or the file stem if the path was never registered.
    Removed(String),

    /// A new `.wasm` file appeared in the watch directory; carries its path.
    Added(PathBuf),
}
293
/// Errors that can occur while hot-reloading a plugin.
///
/// Every variant wraps a human-readable detail message, which `Display` prints after a
/// short category label.
#[derive(Debug, Clone)]
pub enum ReloadError {
    /// A filesystem operation failed.
    FileSystemError(String),

    /// Loading the plugin failed.
    LoadError(String),

    /// Unloading the plugin failed.
    UnloadError(String),
}

impl std::fmt::Display for ReloadError {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::FileSystemError(detail) => write!(out, "File system error: {}", detail),
            Self::LoadError(detail) => write!(out, "Load error: {}", detail),
            Self::UnloadError(detail) => write!(out, "Unload error: {}", detail),
        }
    }
}

/// `ReloadError` has no underlying source; the default `Error` methods suffice.
impl std::error::Error for ReloadError {}
318
319/// Hot reload watcher (for async watching)
320pub struct HotReloadWatcher {
321    /// Reloader
322    reloader: Arc<HotReloader>,
323
324    /// Running flag
325    running: Arc<std::sync::atomic::AtomicBool>,
326}
327
328impl HotReloadWatcher {
329    /// Create a new watcher
330    pub fn new(reloader: Arc<HotReloader>) -> Self {
331        Self {
332            reloader,
333            running: Arc::new(std::sync::atomic::AtomicBool::new(false)),
334        }
335    }
336
337    /// Start watching (returns immediately, runs in background)
338    pub fn start<F>(&self, callback: F)
339    where
340        F: Fn(Vec<ReloadEvent>) + Send + 'static,
341    {
342        self.running
343            .store(true, std::sync::atomic::Ordering::SeqCst);
344
345        let reloader = self.reloader.clone();
346        let running = self.running.clone();
347
348        std::thread::spawn(move || {
349            while running.load(std::sync::atomic::Ordering::SeqCst) {
350                if let Ok(events) = reloader.check() {
351                    if !events.is_empty() {
352                        callback(events);
353                    }
354                }
355
356                std::thread::sleep(Duration::from_millis(100));
357            }
358        });
359    }
360
361    /// Stop watching
362    pub fn stop(&self) {
363        self.running
364            .store(false, std::sync::atomic::Ordering::SeqCst);
365    }
366
367    /// Check if running
368    pub fn is_running(&self) -> bool {
369        self.running.load(std::sync::atomic::Ordering::SeqCst)
370    }
371}
372
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Create (if needed) the directory `name` under the system temp dir and return its path.
    fn scratch_dir(name: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(name);
        fs::create_dir_all(&dir).unwrap();
        dir
    }

    /// Best-effort removal of a scratch directory; failures are ignored.
    fn cleanup(dir: &Path) {
        fs::remove_dir_all(dir).ok();
    }

    #[test]
    fn test_hot_reloader_new() {
        let dir = scratch_dir("hot_reload_test");

        let result = HotReloader::new(&dir);
        assert!(result.is_ok());

        cleanup(&dir);
    }

    #[test]
    fn test_hot_reloader_nonexistent_dir() {
        let missing = Path::new("/nonexistent/path/to/plugins");
        assert!(HotReloader::new(missing).is_err());
    }

    #[test]
    fn test_hot_reloader_register() {
        let dir = scratch_dir("hot_reload_register_test");
        let reloader = HotReloader::new(&dir).unwrap();

        // WebAssembly magic bytes (`\0asm`) followed by binary-format version 1.
        let plugin = dir.join("test-plugin.wasm");
        fs::write(&plugin, b"\x00asm\x01\x00\x00\x00").unwrap();

        reloader.register("test-plugin", &plugin);
        assert_eq!(reloader.plugin_count(), 1);

        reloader.unregister("test-plugin");
        assert_eq!(reloader.plugin_count(), 0);

        cleanup(&dir);
    }

    #[test]
    fn test_reload_event() {
        assert!(matches!(
            ReloadEvent::Modified("test".to_string()),
            ReloadEvent::Modified(_)
        ));
        assert!(matches!(
            ReloadEvent::Added(PathBuf::from("/test.wasm")),
            ReloadEvent::Added(_)
        ));
        assert!(matches!(
            ReloadEvent::Removed("test".to_string()),
            ReloadEvent::Removed(_)
        ));
    }

    #[test]
    fn test_reload_error_display() {
        let fs_error = ReloadError::FileSystemError("test".to_string());
        assert!(fs_error.to_string().contains("File system error"));

        let load_error = ReloadError::LoadError("test".to_string());
        assert!(load_error.to_string().contains("Load error"));
    }

    #[test]
    fn test_hot_reloader_check() {
        let dir = scratch_dir("hot_reload_check_test");
        let reloader = HotReloader::new(&dir).unwrap();

        // An initial check on an empty directory yields no events.
        assert!(reloader.check().unwrap().is_empty());

        cleanup(&dir);
    }

    #[test]
    fn test_hot_reload_watcher() {
        let dir = scratch_dir("hot_reload_watcher_test");

        let watcher = HotReloadWatcher::new(Arc::new(HotReloader::new(&dir).unwrap()));
        assert!(!watcher.is_running());

        cleanup(&dir);
    }

    #[test]
    fn test_debounce_setting() {
        let dir = scratch_dir("hot_reload_debounce_test");

        let mut reloader = HotReloader::new(&dir).unwrap();
        reloader.set_debounce(Duration::from_secs(1));
        reloader.set_check_interval(Duration::from_millis(50));

        assert_eq!(reloader.debounce, Duration::from_secs(1));
        assert_eq!(reloader.check_interval, Duration::from_millis(50));

        cleanup(&dir);
    }
}