Skip to main content

heliosdb_proxy/plugins/
hot_reload.rs

//! Hot Reload Support
//!
//! Polling-based detection of plugin file changes (added, modified and removed
//! `.wasm` files) used to drive automatic plugin reloading.
4
5use std::collections::HashMap;
6use std::path::{Path, PathBuf};
7use std::sync::Arc;
8use std::time::{Duration, Instant, SystemTime};
9
10use parking_lot::RwLock;
11
12use super::runtime::PluginError;
13
14/// Hot reloader for plugins
15pub struct HotReloader {
16    /// Watch directory
17    watch_dir: PathBuf,
18
19    /// File modification times
20    file_times: RwLock<HashMap<PathBuf, SystemTime>>,
21
22    /// Plugin name to file mapping
23    plugin_files: RwLock<HashMap<String, PathBuf>>,
24
25    /// Debounce duration (ignore rapid changes)
26    debounce: Duration,
27
28    /// Last check time
29    last_check: RwLock<Instant>,
30
31    /// Minimum interval between checks
32    check_interval: Duration,
33
34    /// Pending events (debounced)
35    pending_events: RwLock<HashMap<PathBuf, (ReloadEventType, Instant)>>,
36}
37
38impl HotReloader {
39    /// Create a new hot reloader
40    pub fn new(watch_dir: &Path) -> Result<Self, PluginError> {
41        if !watch_dir.exists() {
42            return Err(PluginError::LoadError(format!(
43                "Watch directory does not exist: {}",
44                watch_dir.display()
45            )));
46        }
47
48        Ok(Self {
49            watch_dir: watch_dir.to_path_buf(),
50            file_times: RwLock::new(HashMap::new()),
51            plugin_files: RwLock::new(HashMap::new()),
52            debounce: Duration::from_millis(500),
53            last_check: RwLock::new(Instant::now()),
54            check_interval: Duration::from_millis(100),
55            pending_events: RwLock::new(HashMap::new()),
56        })
57    }
58
59    /// Register a plugin file
60    pub fn register(&self, plugin_name: &str, path: &Path) {
61        let mut plugin_files = self.plugin_files.write();
62        plugin_files.insert(plugin_name.to_string(), path.to_path_buf());
63
64        // Record initial modification time
65        if let Ok(metadata) = std::fs::metadata(path) {
66            if let Ok(modified) = metadata.modified() {
67                let mut file_times = self.file_times.write();
68                file_times.insert(path.to_path_buf(), modified);
69            }
70        }
71    }
72
73    /// Unregister a plugin file
74    pub fn unregister(&self, plugin_name: &str) {
75        let mut plugin_files = self.plugin_files.write();
76        if let Some(path) = plugin_files.remove(plugin_name) {
77            let mut file_times = self.file_times.write();
78            file_times.remove(&path);
79        }
80    }
81
82    /// Check for file changes
83    pub fn check(&self) -> Result<Vec<ReloadEvent>, PluginError> {
84        let now = Instant::now();
85
86        // Rate limit checks
87        {
88            let last = *self.last_check.read();
89            if now.duration_since(last) < self.check_interval {
90                return Ok(Vec::new());
91            }
92            *self.last_check.write() = now;
93        }
94
95        let mut events = Vec::new();
96
97        // Scan watch directory for new/removed files
98        events.extend(self.scan_directory()?);
99
100        // Check registered files for modifications
101        events.extend(self.check_modifications()?);
102
103        // Process pending events (apply debouncing)
104        events.extend(self.process_pending_events(now)?);
105
106        Ok(events)
107    }
108
109    /// Scan directory for new/removed files
110    fn scan_directory(&self) -> Result<Vec<ReloadEvent>, PluginError> {
111        let mut events = Vec::new();
112
113        if !self.watch_dir.exists() {
114            return Ok(events);
115        }
116
117        let entries = std::fs::read_dir(&self.watch_dir)
118            .map_err(|e| PluginError::RuntimeError(e.to_string()))?;
119
120        let mut current_files: HashMap<PathBuf, SystemTime> = HashMap::new();
121
122        for entry in entries.flatten() {
123            let path = entry.path();
124
125            // Only watch .wasm files
126            if path.extension().map(|e| e != "wasm").unwrap_or(true) {
127                continue;
128            }
129
130            if let Ok(metadata) = std::fs::metadata(&path) {
131                if let Ok(modified) = metadata.modified() {
132                    current_files.insert(path, modified);
133                }
134            }
135        }
136
137        // Check for new files
138        let file_times = self.file_times.read();
139        for (path, _) in &current_files {
140            if !file_times.contains_key(path) {
141                // New file detected - add to pending
142                self.add_pending_event(path.clone(), ReloadEventType::Added);
143            }
144        }
145
146        // Check for removed files
147        for path in file_times.keys() {
148            if path.starts_with(&self.watch_dir) && !current_files.contains_key(path) {
149                // File removed
150                if let Some(name) = self.get_plugin_name(path) {
151                    events.push(ReloadEvent::Removed(name));
152                }
153            }
154        }
155
156        Ok(events)
157    }
158
159    /// Check registered files for modifications
160    fn check_modifications(&self) -> Result<Vec<ReloadEvent>, PluginError> {
161        let plugin_files = self.plugin_files.read();
162        let file_times = self.file_times.read();
163
164        for (_plugin_name, path) in plugin_files.iter() {
165            if let Ok(metadata) = std::fs::metadata(path) {
166                if let Ok(modified) = metadata.modified() {
167                    if let Some(old_time) = file_times.get(path) {
168                        if modified > *old_time {
169                            // File modified - add to pending
170                            self.add_pending_event(path.clone(), ReloadEventType::Modified);
171                        }
172                    }
173                }
174            }
175        }
176
177        Ok(Vec::new())
178    }
179
180    /// Add a pending event (for debouncing)
181    fn add_pending_event(&self, path: PathBuf, event_type: ReloadEventType) {
182        let mut pending = self.pending_events.write();
183        pending.insert(path, (event_type, Instant::now()));
184    }
185
186    /// Process pending events after debounce period
187    fn process_pending_events(&self, now: Instant) -> Result<Vec<ReloadEvent>, PluginError> {
188        let mut events = Vec::new();
189        let mut to_remove = Vec::new();
190
191        {
192            let pending = self.pending_events.read();
193            for (path, (event_type, timestamp)) in pending.iter() {
194                if now.duration_since(*timestamp) >= self.debounce {
195                    match event_type {
196                        ReloadEventType::Modified => {
197                            if let Some(name) = self.get_plugin_name(path) {
198                                events.push(ReloadEvent::Modified(name));
199                            }
200                        }
201                        ReloadEventType::Added => {
202                            events.push(ReloadEvent::Added(path.clone()));
203                        }
204                        ReloadEventType::Removed => {
205                            if let Some(name) = self.get_plugin_name(path) {
206                                events.push(ReloadEvent::Removed(name));
207                            }
208                        }
209                    }
210                    to_remove.push(path.clone());
211                }
212            }
213        }
214
215        // Remove processed events and update file times
216        {
217            let mut pending = self.pending_events.write();
218            let mut file_times = self.file_times.write();
219
220            for path in to_remove {
221                pending.remove(&path);
222
223                // Update file time
224                if let Ok(metadata) = std::fs::metadata(&path) {
225                    if let Ok(modified) = metadata.modified() {
226                        file_times.insert(path, modified);
227                    }
228                }
229            }
230        }
231
232        Ok(events)
233    }
234
235    /// Get plugin name for a path
236    fn get_plugin_name(&self, path: &Path) -> Option<String> {
237        let plugin_files = self.plugin_files.read();
238        for (name, p) in plugin_files.iter() {
239            if p == path {
240                return Some(name.clone());
241            }
242        }
243
244        // Fall back to filename
245        path.file_stem().and_then(|s| s.to_str()).map(|s| s.to_string())
246    }
247
248    /// Set debounce duration
249    pub fn set_debounce(&mut self, duration: Duration) {
250        self.debounce = duration;
251    }
252
253    /// Set check interval
254    pub fn set_check_interval(&mut self, interval: Duration) {
255        self.check_interval = interval;
256    }
257
258    /// Get watch directory
259    pub fn watch_dir(&self) -> &Path {
260        &self.watch_dir
261    }
262
263    /// Get registered plugin count
264    pub fn plugin_count(&self) -> usize {
265        self.plugin_files.read().len()
266    }
267}
268
/// Kind of change waiting in the debounce queue (internal counterpart of
/// [`ReloadEvent`]).
#[derive(Clone, Debug, Eq, PartialEq)]
enum ReloadEventType {
    /// A tracked file's modification time advanced.
    Modified,
    /// A new `.wasm` file appeared in the watch directory.
    Added,
    /// A tracked file disappeared.
    Removed,
}
276
/// Reload event reported by [`HotReloader::check`].
///
/// Derives `PartialEq`/`Eq` so callers can compare or deduplicate events.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReloadEvent {
    /// A registered plugin's file was modified; carries the plugin name.
    Modified(String),

    /// A plugin file was removed; carries the plugin name (the file stem when
    /// the file was not registered).
    Removed(String),

    /// A new `.wasm` file appeared in the watch directory; carries its path.
    Added(PathBuf),
}
289
/// Error raised while hot-reloading a plugin.
#[derive(Debug, Clone)]
pub enum ReloadError {
    /// Reading or stat-ing a plugin file failed.
    FileSystemError(String),

    /// Loading the new plugin version failed.
    LoadError(String),

    /// Unloading the previous plugin version failed.
    UnloadError(String),
}

impl std::fmt::Display for ReloadError {
    /// Renders as `"<category>: <detail>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, detail) = match self {
            Self::FileSystemError(detail) => ("File system error", detail),
            Self::LoadError(detail) => ("Load error", detail),
            Self::UnloadError(detail) => ("Unload error", detail),
        };
        write!(f, "{}: {}", label, detail)
    }
}

impl std::error::Error for ReloadError {}
314
315/// Hot reload watcher (for async watching)
316pub struct HotReloadWatcher {
317    /// Reloader
318    reloader: Arc<HotReloader>,
319
320    /// Running flag
321    running: Arc<std::sync::atomic::AtomicBool>,
322}
323
324impl HotReloadWatcher {
325    /// Create a new watcher
326    pub fn new(reloader: Arc<HotReloader>) -> Self {
327        Self {
328            reloader,
329            running: Arc::new(std::sync::atomic::AtomicBool::new(false)),
330        }
331    }
332
333    /// Start watching (returns immediately, runs in background)
334    pub fn start<F>(&self, callback: F)
335    where
336        F: Fn(Vec<ReloadEvent>) + Send + 'static,
337    {
338        self.running.store(true, std::sync::atomic::Ordering::SeqCst);
339
340        let reloader = self.reloader.clone();
341        let running = self.running.clone();
342
343        std::thread::spawn(move || {
344            while running.load(std::sync::atomic::Ordering::SeqCst) {
345                if let Ok(events) = reloader.check() {
346                    if !events.is_empty() {
347                        callback(events);
348                    }
349                }
350
351                std::thread::sleep(Duration::from_millis(100));
352            }
353        });
354    }
355
356    /// Stop watching
357    pub fn stop(&self) {
358        self.running.store(false, std::sync::atomic::Ordering::SeqCst);
359    }
360
361    /// Check if running
362    pub fn is_running(&self) -> bool {
363        self.running.load(std::sync::atomic::Ordering::SeqCst)
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Create (if needed) and return a scratch directory under the temp dir.
    fn scratch_dir(name: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(name);
        fs::create_dir_all(&dir).unwrap();
        dir
    }

    /// Best-effort cleanup of a scratch directory.
    fn discard(dir: &Path) {
        let _ = fs::remove_dir_all(dir);
    }

    #[test]
    fn test_hot_reloader_new() {
        let dir = scratch_dir("hot_reload_test");

        assert!(HotReloader::new(&dir).is_ok());

        discard(&dir);
    }

    #[test]
    fn test_hot_reloader_nonexistent_dir() {
        let missing = Path::new("/nonexistent/path/to/plugins");
        assert!(HotReloader::new(missing).is_err());
    }

    #[test]
    fn test_hot_reloader_register() {
        let dir = scratch_dir("hot_reload_register_test");
        let reloader = HotReloader::new(&dir).unwrap();

        let wasm = dir.join("test-plugin.wasm");
        fs::write(&wasm, b"\x00asm\x01\x00\x00\x00").unwrap();

        reloader.register("test-plugin", &wasm);
        assert_eq!(reloader.plugin_count(), 1);

        reloader.unregister("test-plugin");
        assert_eq!(reloader.plugin_count(), 0);

        discard(&dir);
    }

    #[test]
    fn test_reload_event() {
        assert!(matches!(
            ReloadEvent::Modified("test".to_string()),
            ReloadEvent::Modified(_)
        ));
        assert!(matches!(
            ReloadEvent::Added(PathBuf::from("/test.wasm")),
            ReloadEvent::Added(_)
        ));
        assert!(matches!(
            ReloadEvent::Removed("test".to_string()),
            ReloadEvent::Removed(_)
        ));
    }

    #[test]
    fn test_reload_error_display() {
        let cases = [
            (ReloadError::FileSystemError("test".to_string()), "File system error"),
            (ReloadError::LoadError("test".to_string()), "Load error"),
        ];
        for (err, needle) in cases.iter() {
            assert!(err.to_string().contains(needle));
        }
    }

    #[test]
    fn test_hot_reloader_check() {
        let dir = scratch_dir("hot_reload_check_test");
        let reloader = HotReloader::new(&dir).unwrap();

        // Initial check should return empty
        assert!(reloader.check().unwrap().is_empty());

        discard(&dir);
    }

    #[test]
    fn test_hot_reload_watcher() {
        let dir = scratch_dir("hot_reload_watcher_test");
        let watcher = HotReloadWatcher::new(Arc::new(HotReloader::new(&dir).unwrap()));

        assert!(!watcher.is_running());

        discard(&dir);
    }

    #[test]
    fn test_debounce_setting() {
        let dir = scratch_dir("hot_reload_debounce_test");
        let mut reloader = HotReloader::new(&dir).unwrap();

        reloader.set_debounce(Duration::from_secs(1));
        reloader.set_check_interval(Duration::from_millis(50));

        assert_eq!(reloader.debounce, Duration::from_secs(1));
        assert_eq!(reloader.check_interval, Duration::from_millis(50));

        discard(&dir);
    }
}