Skip to main content

mofa_plugins/hot_reload/
loader.rs

1//! Dynamic plugin loader
2//!
3//! Handles loading and unloading of dynamic plugins from shared libraries
4
5use crate::{AgentPlugin, PluginMetadata};
6use libloading::{Library, Symbol};
7use sha2::{Digest, Sha256};
8use std::collections::HashMap;
9use std::path::{Path, PathBuf};
10use std::sync::Arc;
11use tokio::sync::RwLock;
12use tracing::{debug, info};
13
14/// Plugin load error types
15#[derive(Debug, thiserror::Error)]
16pub enum PluginLoadError {
17    #[error("Failed to load library: {0}")]
18    LibraryLoad(String),
19
20    #[error("Symbol not found: {0}")]
21    SymbolNotFound(String),
22
23    #[error("Plugin creation failed: {0}")]
24    CreationFailed(String),
25
26    #[error("Invalid plugin: {0}")]
27    InvalidPlugin(String),
28
29    #[error("Version mismatch: expected {expected}, got {actual}")]
30    VersionMismatch { expected: String, actual: String },
31
32    #[error("IO error: {0}")]
33    IoError(#[from] std::io::Error),
34
35    #[error("Plugin already loaded: {0}")]
36    AlreadyLoaded(String),
37
38    #[error("Plugin not found: {0}")]
39    NotFound(String),
40}
41
42/// Required symbols for a plugin library
43pub struct PluginSymbols {
44    /// Create plugin function
45    pub create: Symbol<'static, unsafe extern "C" fn() -> *mut dyn AgentPlugin>,
46    /// Destroy plugin function
47    pub destroy: Symbol<'static, unsafe extern "C" fn(*mut dyn AgentPlugin)>,
48    /// Get plugin metadata function
49    pub metadata: Symbol<'static, unsafe extern "C" fn() -> PluginMetadata>,
50    /// Get API version function
51    pub api_version: Symbol<'static, unsafe extern "C" fn() -> u32>,
52}
53
54/// Represents a loaded plugin library
55pub struct PluginLibrary {
56    /// Path to the library file
57    path: PathBuf,
58    /// The loaded library
59    library: Library,
60    /// File hash for change detection
61    hash: String,
62    /// Load timestamp
63    loaded_at: std::time::Instant,
64    /// Plugin metadata
65    metadata: PluginMetadata,
66    /// API version
67    api_version: u32,
68}
69
70impl PluginLibrary {
71    /// Get the library path
72    pub fn path(&self) -> &Path {
73        &self.path
74    }
75
76    /// Get the file hash
77    pub fn hash(&self) -> &str {
78        &self.hash
79    }
80
81    /// Get when the library was loaded
82    pub fn loaded_at(&self) -> std::time::Instant {
83        self.loaded_at
84    }
85
86    /// Get plugin metadata
87    pub fn metadata(&self) -> &PluginMetadata {
88        &self.metadata
89    }
90
91    /// Get API version
92    pub fn api_version(&self) -> u32 {
93        self.api_version
94    }
95
96    /// Create a new plugin instance from this library
97    ///
98    /// # Safety
99    /// This function calls extern "C" functions from a dynamic library
100    pub unsafe fn create_instance(&self) -> Result<Box<dyn AgentPlugin>, PluginLoadError> {
101        unsafe {
102            let create_fn: Symbol<unsafe extern "C" fn() -> *mut dyn AgentPlugin> = self
103                .library
104                .get(b"_plugin_create")
105                .map_err(|e| PluginLoadError::SymbolNotFound(format!("_plugin_create: {}", e)))?;
106
107            let raw_plugin = create_fn();
108            if raw_plugin.is_null() {
109                return Err(PluginLoadError::CreationFailed(
110                    "Plugin creation returned null".to_string(),
111                ));
112            }
113
114            Ok(Box::from_raw(raw_plugin))
115        }
116    }
117
118    /// Destroy a plugin instance
119    ///
120    /// # Safety
121    /// This function calls extern "C" functions from a dynamic library
122    pub unsafe fn destroy_instance(
123        &self,
124        plugin: Box<dyn AgentPlugin>,
125    ) -> Result<(), PluginLoadError> {
126        unsafe {
127            let destroy_fn: Symbol<unsafe extern "C" fn(*mut dyn AgentPlugin)> = self
128                .library
129                .get(b"_plugin_destroy")
130                .map_err(|e| PluginLoadError::SymbolNotFound(format!("_plugin_destroy: {}", e)))?;
131
132            let raw = Box::into_raw(plugin);
133            destroy_fn(raw);
134            Ok(())
135        }
136    }
137}
138
139impl Drop for PluginLibrary {
140    fn drop(&mut self) {
141        debug!("Unloading plugin library: {:?}", self.path);
142    }
143}
144
145/// A dynamically loaded plugin wrapper
146pub struct DynamicPlugin {
147    /// The plugin instance
148    plugin: Box<dyn AgentPlugin>,
149    /// Reference to the library
150    library_path: PathBuf,
151    /// Instance ID for tracking
152    instance_id: String,
153    /// Creation time
154    created_at: std::time::Instant,
155}
156
157impl DynamicPlugin {
158    /// Create a new dynamic plugin
159    pub fn new(plugin: Box<dyn AgentPlugin>, library_path: PathBuf) -> Self {
160        Self {
161            plugin,
162            library_path,
163            instance_id: uuid::Uuid::now_v7().to_string(),
164            created_at: std::time::Instant::now(),
165        }
166    }
167
168    /// Get the inner plugin
169    pub fn plugin(&self) -> &dyn AgentPlugin {
170        self.plugin.as_ref()
171    }
172
173    /// Get mutable reference to the inner plugin
174    pub fn plugin_mut(&mut self) -> &mut dyn AgentPlugin {
175        self.plugin.as_mut()
176    }
177
178    /// Get the library path
179    pub fn library_path(&self) -> &Path {
180        &self.library_path
181    }
182
183    /// Get the instance ID
184    pub fn instance_id(&self) -> &str {
185        &self.instance_id
186    }
187
188    /// Get creation time
189    pub fn created_at(&self) -> std::time::Instant {
190        self.created_at
191    }
192
193    /// Consume and return the inner plugin
194    pub fn into_inner(self) -> Box<dyn AgentPlugin> {
195        self.plugin
196    }
197}
198
199/// Plugin loader for managing dynamic plugin loading
200pub struct PluginLoader {
201    /// Loaded libraries
202    libraries: Arc<RwLock<HashMap<PathBuf, Arc<PluginLibrary>>>>,
203    /// Plugin search paths
204    search_paths: Vec<PathBuf>,
205    /// Expected API version
206    api_version: u32,
207    /// Enable unsafe loading (skip validation)
208    unsafe_mode: bool,
209}
210
211impl PluginLoader {
212    /// Current API version
213    pub const CURRENT_API_VERSION: u32 = 1;
214
215    /// Create a new plugin loader
216    pub fn new() -> Self {
217        Self {
218            libraries: Arc::new(RwLock::new(HashMap::new())),
219            search_paths: Vec::new(),
220            api_version: Self::CURRENT_API_VERSION,
221            unsafe_mode: false,
222        }
223    }
224
225    /// Add a search path for plugins
226    pub fn add_search_path<P: AsRef<Path>>(&mut self, path: P) {
227        self.search_paths.push(path.as_ref().to_path_buf());
228    }
229
230    /// Enable unsafe mode (skip validation)
231    pub fn set_unsafe_mode(&mut self, enabled: bool) {
232        self.unsafe_mode = enabled;
233    }
234
235    /// Calculate file hash
236    fn calculate_hash(path: &Path) -> Result<String, PluginLoadError> {
237        let contents = std::fs::read(path)?;
238        let mut hasher = Sha256::new();
239        hasher.update(&contents);
240        Ok(format!("{:x}", hasher.finalize()))
241    }
242
243    /// Find plugin file by name
244    pub fn find_plugin(&self, name: &str) -> Option<PathBuf> {
245        let lib_name = if cfg!(target_os = "windows") {
246            format!("{}.dll", name)
247        } else if cfg!(target_os = "macos") {
248            format!("lib{}.dylib", name)
249        } else {
250            format!("lib{}.so", name)
251        };
252
253        // First check if it's an absolute path
254        let direct_path = PathBuf::from(name);
255        if direct_path.exists() {
256            return Some(direct_path);
257        }
258
259        // Search in search paths
260        for search_path in &self.search_paths {
261            let full_path = search_path.join(&lib_name);
262            if full_path.exists() {
263                return Some(full_path);
264            }
265        }
266
267        // Check current directory
268        let current_path = PathBuf::from(&lib_name);
269        if current_path.exists() {
270            return Some(current_path);
271        }
272
273        None
274    }
275
276    /// Load a plugin library from file
277    ///
278    /// # Safety
279    /// Loading dynamic libraries is inherently unsafe
280    pub async fn load_library<P: AsRef<Path>>(
281        &self,
282        path: P,
283    ) -> Result<Arc<PluginLibrary>, PluginLoadError> {
284        let path = path.as_ref().to_path_buf();
285
286        // Check if already loaded
287        {
288            let libraries = self.libraries.read().await;
289            if let Some(lib) = libraries.get(&path) {
290                return Ok(lib.clone());
291            }
292        }
293
294        info!("Loading plugin library: {:?}", path);
295
296        // Calculate hash
297        let hash = Self::calculate_hash(&path)?;
298
299        // Load the library
300        let library = unsafe {
301            Library::new(&path).map_err(|e| PluginLoadError::LibraryLoad(e.to_string()))?
302        };
303
304        // Get API version
305        let api_version = unsafe {
306            let version_fn: Result<Symbol<unsafe extern "C" fn() -> u32>, _> =
307                library.get(b"_plugin_api_version");
308
309            match version_fn {
310                Ok(func) => func(),
311                Err(_) => 1, // Default to version 1 if not specified
312            }
313        };
314
315        // Validate API version
316        if !self.unsafe_mode && api_version != self.api_version {
317            return Err(PluginLoadError::VersionMismatch {
318                expected: self.api_version.to_string(),
319                actual: api_version.to_string(),
320            });
321        }
322
323        // Get metadata
324        let metadata = unsafe {
325            let metadata_fn: Symbol<unsafe extern "C" fn() -> PluginMetadata> = library
326                .get(b"_plugin_metadata")
327                .map_err(|e| PluginLoadError::SymbolNotFound(format!("_plugin_metadata: {}", e)))?;
328            metadata_fn()
329        };
330
331        let plugin_lib = Arc::new(PluginLibrary {
332            path: path.clone(),
333            library,
334            hash,
335            loaded_at: std::time::Instant::now(),
336            metadata,
337            api_version,
338        });
339
340        // Store in cache
341        {
342            let mut libraries = self.libraries.write().await;
343            libraries.insert(path.clone(), plugin_lib.clone());
344        }
345
346        info!(
347            "Loaded plugin: {} v{}",
348            plugin_lib.metadata.name, plugin_lib.metadata.version
349        );
350
351        Ok(plugin_lib)
352    }
353
354    /// Unload a plugin library
355    pub async fn unload_library<P: AsRef<Path>>(&self, path: P) -> Result<(), PluginLoadError> {
356        let path = path.as_ref().to_path_buf();
357
358        let mut libraries = self.libraries.write().await;
359        if libraries.remove(&path).is_some() {
360            info!("Unloaded plugin library: {:?}", path);
361            Ok(())
362        } else {
363            Err(PluginLoadError::NotFound(path.display().to_string()))
364        }
365    }
366
367    /// Check if a library has changed
368    pub async fn has_changed<P: AsRef<Path>>(&self, path: P) -> Result<bool, PluginLoadError> {
369        let path = path.as_ref().to_path_buf();
370
371        let libraries = self.libraries.read().await;
372        if let Some(lib) = libraries.get(&path) {
373            let current_hash = Self::calculate_hash(&path)?;
374            Ok(current_hash != lib.hash)
375        } else {
376            Ok(true) // Not loaded, so consider it "changed"
377        }
378    }
379
380    /// Get a loaded library
381    pub async fn get_library<P: AsRef<Path>>(&self, path: P) -> Option<Arc<PluginLibrary>> {
382        let libraries = self.libraries.read().await;
383        libraries.get(path.as_ref()).cloned()
384    }
385
386    /// List all loaded libraries
387    pub async fn list_libraries(&self) -> Vec<PathBuf> {
388        let libraries = self.libraries.read().await;
389        libraries.keys().cloned().collect()
390    }
391
392    /// Create a plugin instance from a loaded library
393    pub async fn create_plugin<P: AsRef<Path>>(
394        &self,
395        path: P,
396    ) -> Result<DynamicPlugin, PluginLoadError> {
397        let path = path.as_ref().to_path_buf();
398
399        let library = self.load_library(&path).await?;
400
401        let plugin = unsafe { library.create_instance()? };
402
403        Ok(DynamicPlugin::new(plugin, path))
404    }
405
406    /// Reload a plugin library
407    pub async fn reload_library<P: AsRef<Path>>(
408        &self,
409        path: P,
410    ) -> Result<Arc<PluginLibrary>, PluginLoadError> {
411        let path = path.as_ref().to_path_buf();
412
413        // Unload existing
414        let _ = self.unload_library(&path).await;
415
416        // Small delay to ensure file handle is released
417        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
418
419        // Load fresh
420        self.load_library(&path).await
421    }
422}
423
424impl Default for PluginLoader {
425    fn default() -> Self {
426        Self::new()
427    }
428}
429
/// Export the C-ABI entry points that `PluginLoader` looks up in a plugin library.
///
/// `$plugin_type` is the concrete plugin type and `$create_fn` an expression that builds
/// one. The macro defines four `#[no_mangle]` functions:
///
/// - `_plugin_create`: boxes a new instance and returns it as a raw trait-object pointer;
/// - `_plugin_destroy`: reclaims and drops a pointer produced by `_plugin_create`
///   (a null pointer is ignored);
/// - `_plugin_api_version`: returns `PluginLoader::CURRENT_API_VERSION` as compiled into
///   the plugin;
/// - `_plugin_metadata`: evaluates `$create_fn` to build a temporary instance and returns
///   a clone of its metadata.
///
/// `$create_fn` is evaluated on every `_plugin_create` call and again on every
/// `_plugin_metadata` call, so it should be cheap and free of side effects.
///
/// NOTE(review): these functions pass Rust trait-object pointers and `PluginMetadata` by
/// value through `extern "C"`. That only works when the host and the plugin are built
/// with the same compiler and crate versions; confirm this is the intended deployment model.
#[macro_export]
macro_rules! declare_plugin {
    ($plugin_type:ty, $create_fn:expr) => {
        // Hands ownership of a freshly boxed instance to the host as a raw pointer.
        #[no_mangle]
        pub extern "C" fn _plugin_create() -> *mut dyn $crate::plugin::AgentPlugin {
            let boxed: Box<dyn $crate::plugin::AgentPlugin> = Box::new($create_fn);
            Box::into_raw(boxed)
        }

        // Takes ownership back from the host and drops the instance inside the plugin.
        #[no_mangle]
        pub extern "C" fn _plugin_destroy(plugin: *mut dyn $crate::plugin::AgentPlugin) {
            if plugin.is_null() {
                return;
            }
            drop(unsafe { Box::from_raw(plugin) });
        }

        // Reports the plugin API version this plugin was compiled against.
        #[no_mangle]
        pub extern "C" fn _plugin_api_version() -> u32 {
            $crate::hot_reload::PluginLoader::CURRENT_API_VERSION
        }

        // Builds a throwaway instance solely to read its metadata.
        #[no_mangle]
        pub extern "C" fn _plugin_metadata() -> $crate::plugin::PluginMetadata {
            let instance: $plugin_type = $create_fn;
            instance.metadata().clone()
        }
    };
}
461
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_plugin_loader_new() {
        let loader = PluginLoader::new();
        assert_eq!(loader.api_version, PluginLoader::CURRENT_API_VERSION);
        assert!(!loader.unsafe_mode);
    }

    #[tokio::test]
    async fn test_search_paths() {
        let mut loader = PluginLoader::new();
        loader.add_search_path("/usr/lib/plugins");
        loader.add_search_path("/opt/plugins");
        assert_eq!(loader.search_paths.len(), 2);
    }

    #[test]
    fn test_calculate_hash() {
        // Create a temp file
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("test.txt");
        std::fs::write(&file_path, b"test content").unwrap();

        let hash1 = PluginLoader::calculate_hash(&file_path).unwrap();
        let hash2 = PluginLoader::calculate_hash(&file_path).unwrap();

        assert_eq!(hash1, hash2);
        assert!(!hash1.is_empty());
    }

    // Change detection relies on the hash actually changing with the file contents.
    #[test]
    fn test_calculate_hash_detects_content_change() {
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("test.txt");

        std::fs::write(&file_path, b"before").unwrap();
        let before = PluginLoader::calculate_hash(&file_path).unwrap();
        std::fs::write(&file_path, b"after").unwrap();
        let after = PluginLoader::calculate_hash(&file_path).unwrap();

        assert_ne!(before, after);
        // A SHA-256 digest rendered as hex is 64 characters long.
        assert_eq!(before.len(), 64);
    }

    #[test]
    fn test_calculate_hash_missing_file() {
        let temp_dir = tempfile::tempdir().unwrap();
        let result = PluginLoader::calculate_hash(&temp_dir.path().join("missing"));
        assert!(matches!(result, Err(PluginLoadError::IoError(_))));
    }

    #[test]
    fn test_find_plugin_direct_path() {
        let temp_dir = tempfile::tempdir().unwrap();
        let file_path = temp_dir.path().join("custom_plugin.bin");
        std::fs::write(&file_path, b"").unwrap();

        let loader = PluginLoader::new();
        let name = file_path.to_str().unwrap();
        assert_eq!(loader.find_plugin(name), Some(file_path.clone()));
    }

    #[test]
    fn test_find_plugin_in_search_path() {
        let temp_dir = tempfile::tempdir().unwrap();
        let lib_name = format!(
            "{}demo_plugin{}",
            std::env::consts::DLL_PREFIX,
            std::env::consts::DLL_SUFFIX
        );
        let file_path = temp_dir.path().join(&lib_name);
        std::fs::write(&file_path, b"").unwrap();

        let mut loader = PluginLoader::new();
        loader.add_search_path(temp_dir.path());
        assert_eq!(loader.find_plugin("demo_plugin"), Some(file_path));
    }

    #[test]
    fn test_find_plugin_missing() {
        let loader = PluginLoader::new();
        assert_eq!(loader.find_plugin("definitely_not_a_real_plugin_xyz"), None);
    }

    #[tokio::test]
    async fn test_unload_unknown_library_is_not_found() {
        let loader = PluginLoader::new();
        let result = loader.unload_library("/nonexistent/libplugin.so").await;
        assert!(matches!(result, Err(PluginLoadError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_has_changed_for_unloaded_library() {
        let loader = PluginLoader::new();
        assert!(loader.has_changed("/nonexistent/libplugin.so").await.unwrap());
        assert!(loader.list_libraries().await.is_empty());
    }

    #[tokio::test]
    async fn test_load_library_missing_file_is_not_cached() {
        let loader = PluginLoader::new();
        let result = loader.load_library("/nonexistent/libplugin.so").await;
        assert!(matches!(result, Err(PluginLoadError::IoError(_))));
        assert!(loader.get_library("/nonexistent/libplugin.so").await.is_none());
    }
}