1use crate::{AgentPlugin, PluginMetadata};
6use libloading::{Library, Symbol};
7use sha2::{Digest, Sha256};
8use std::collections::HashMap;
9use std::path::{Path, PathBuf};
10use std::sync::Arc;
11use tokio::sync::RwLock;
12use tracing::{debug, info};
13
14#[derive(Debug, thiserror::Error)]
16pub enum PluginLoadError {
17 #[error("Failed to load library: {0}")]
18 LibraryLoad(String),
19
20 #[error("Symbol not found: {0}")]
21 SymbolNotFound(String),
22
23 #[error("Plugin creation failed: {0}")]
24 CreationFailed(String),
25
26 #[error("Invalid plugin: {0}")]
27 InvalidPlugin(String),
28
29 #[error("Version mismatch: expected {expected}, got {actual}")]
30 VersionMismatch { expected: String, actual: String },
31
32 #[error("IO error: {0}")]
33 IoError(#[from] std::io::Error),
34
35 #[error("Plugin already loaded: {0}")]
36 AlreadyLoaded(String),
37
38 #[error("Plugin not found: {0}")]
39 NotFound(String),
40}
41
42pub struct PluginSymbols {
44 pub create: Symbol<'static, unsafe extern "C" fn() -> *mut dyn AgentPlugin>,
46 pub destroy: Symbol<'static, unsafe extern "C" fn(*mut dyn AgentPlugin)>,
48 pub metadata: Symbol<'static, unsafe extern "C" fn() -> PluginMetadata>,
50 pub api_version: Symbol<'static, unsafe extern "C" fn() -> u32>,
52}
53
54pub struct PluginLibrary {
56 path: PathBuf,
58 library: Library,
60 hash: String,
62 loaded_at: std::time::Instant,
64 metadata: PluginMetadata,
66 api_version: u32,
68}
69
70impl PluginLibrary {
71 pub fn path(&self) -> &Path {
73 &self.path
74 }
75
76 pub fn hash(&self) -> &str {
78 &self.hash
79 }
80
81 pub fn loaded_at(&self) -> std::time::Instant {
83 self.loaded_at
84 }
85
86 pub fn metadata(&self) -> &PluginMetadata {
88 &self.metadata
89 }
90
91 pub fn api_version(&self) -> u32 {
93 self.api_version
94 }
95
96 pub unsafe fn create_instance(&self) -> Result<Box<dyn AgentPlugin>, PluginLoadError> {
101 unsafe {
102 let create_fn: Symbol<unsafe extern "C" fn() -> *mut dyn AgentPlugin> = self
103 .library
104 .get(b"_plugin_create")
105 .map_err(|e| PluginLoadError::SymbolNotFound(format!("_plugin_create: {}", e)))?;
106
107 let raw_plugin = create_fn();
108 if raw_plugin.is_null() {
109 return Err(PluginLoadError::CreationFailed(
110 "Plugin creation returned null".to_string(),
111 ));
112 }
113
114 Ok(Box::from_raw(raw_plugin))
115 }
116 }
117
118 pub unsafe fn destroy_instance(
123 &self,
124 plugin: Box<dyn AgentPlugin>,
125 ) -> Result<(), PluginLoadError> {
126 unsafe {
127 let destroy_fn: Symbol<unsafe extern "C" fn(*mut dyn AgentPlugin)> = self
128 .library
129 .get(b"_plugin_destroy")
130 .map_err(|e| PluginLoadError::SymbolNotFound(format!("_plugin_destroy: {}", e)))?;
131
132 let raw = Box::into_raw(plugin);
133 destroy_fn(raw);
134 Ok(())
135 }
136 }
137}
138
139impl Drop for PluginLibrary {
140 fn drop(&mut self) {
141 debug!("Unloading plugin library: {:?}", self.path);
142 }
143}
144
145pub struct DynamicPlugin {
147 plugin: Box<dyn AgentPlugin>,
149 library_path: PathBuf,
151 instance_id: String,
153 created_at: std::time::Instant,
155}
156
157impl DynamicPlugin {
158 pub fn new(plugin: Box<dyn AgentPlugin>, library_path: PathBuf) -> Self {
160 Self {
161 plugin,
162 library_path,
163 instance_id: uuid::Uuid::now_v7().to_string(),
164 created_at: std::time::Instant::now(),
165 }
166 }
167
168 pub fn plugin(&self) -> &dyn AgentPlugin {
170 self.plugin.as_ref()
171 }
172
173 pub fn plugin_mut(&mut self) -> &mut dyn AgentPlugin {
175 self.plugin.as_mut()
176 }
177
178 pub fn library_path(&self) -> &Path {
180 &self.library_path
181 }
182
183 pub fn instance_id(&self) -> &str {
185 &self.instance_id
186 }
187
188 pub fn created_at(&self) -> std::time::Instant {
190 self.created_at
191 }
192
193 pub fn into_inner(self) -> Box<dyn AgentPlugin> {
195 self.plugin
196 }
197}
198
199pub struct PluginLoader {
201 libraries: Arc<RwLock<HashMap<PathBuf, Arc<PluginLibrary>>>>,
203 search_paths: Vec<PathBuf>,
205 api_version: u32,
207 unsafe_mode: bool,
209}
210
211impl PluginLoader {
212 pub const CURRENT_API_VERSION: u32 = 1;
214
215 pub fn new() -> Self {
217 Self {
218 libraries: Arc::new(RwLock::new(HashMap::new())),
219 search_paths: Vec::new(),
220 api_version: Self::CURRENT_API_VERSION,
221 unsafe_mode: false,
222 }
223 }
224
225 pub fn add_search_path<P: AsRef<Path>>(&mut self, path: P) {
227 self.search_paths.push(path.as_ref().to_path_buf());
228 }
229
230 pub fn set_unsafe_mode(&mut self, enabled: bool) {
232 self.unsafe_mode = enabled;
233 }
234
235 fn calculate_hash(path: &Path) -> Result<String, PluginLoadError> {
237 let contents = std::fs::read(path)?;
238 let mut hasher = Sha256::new();
239 hasher.update(&contents);
240 Ok(format!("{:x}", hasher.finalize()))
241 }
242
243 pub fn find_plugin(&self, name: &str) -> Option<PathBuf> {
245 let lib_name = if cfg!(target_os = "windows") {
246 format!("{}.dll", name)
247 } else if cfg!(target_os = "macos") {
248 format!("lib{}.dylib", name)
249 } else {
250 format!("lib{}.so", name)
251 };
252
253 let direct_path = PathBuf::from(name);
255 if direct_path.exists() {
256 return Some(direct_path);
257 }
258
259 for search_path in &self.search_paths {
261 let full_path = search_path.join(&lib_name);
262 if full_path.exists() {
263 return Some(full_path);
264 }
265 }
266
267 let current_path = PathBuf::from(&lib_name);
269 if current_path.exists() {
270 return Some(current_path);
271 }
272
273 None
274 }
275
276 pub async fn load_library<P: AsRef<Path>>(
281 &self,
282 path: P,
283 ) -> Result<Arc<PluginLibrary>, PluginLoadError> {
284 let path = path.as_ref().to_path_buf();
285
286 {
288 let libraries = self.libraries.read().await;
289 if let Some(lib) = libraries.get(&path) {
290 return Ok(lib.clone());
291 }
292 }
293
294 info!("Loading plugin library: {:?}", path);
295
296 let hash = Self::calculate_hash(&path)?;
298
299 let library = unsafe {
301 Library::new(&path).map_err(|e| PluginLoadError::LibraryLoad(e.to_string()))?
302 };
303
304 let api_version = unsafe {
306 let version_fn: Result<Symbol<unsafe extern "C" fn() -> u32>, _> =
307 library.get(b"_plugin_api_version");
308
309 match version_fn {
310 Ok(func) => func(),
311 Err(_) => 1, }
313 };
314
315 if !self.unsafe_mode && api_version != self.api_version {
317 return Err(PluginLoadError::VersionMismatch {
318 expected: self.api_version.to_string(),
319 actual: api_version.to_string(),
320 });
321 }
322
323 let metadata = unsafe {
325 let metadata_fn: Symbol<unsafe extern "C" fn() -> PluginMetadata> = library
326 .get(b"_plugin_metadata")
327 .map_err(|e| PluginLoadError::SymbolNotFound(format!("_plugin_metadata: {}", e)))?;
328 metadata_fn()
329 };
330
331 let plugin_lib = Arc::new(PluginLibrary {
332 path: path.clone(),
333 library,
334 hash,
335 loaded_at: std::time::Instant::now(),
336 metadata,
337 api_version,
338 });
339
340 {
342 let mut libraries = self.libraries.write().await;
343 libraries.insert(path.clone(), plugin_lib.clone());
344 }
345
346 info!(
347 "Loaded plugin: {} v{}",
348 plugin_lib.metadata.name, plugin_lib.metadata.version
349 );
350
351 Ok(plugin_lib)
352 }
353
354 pub async fn unload_library<P: AsRef<Path>>(&self, path: P) -> Result<(), PluginLoadError> {
356 let path = path.as_ref().to_path_buf();
357
358 let mut libraries = self.libraries.write().await;
359 if libraries.remove(&path).is_some() {
360 info!("Unloaded plugin library: {:?}", path);
361 Ok(())
362 } else {
363 Err(PluginLoadError::NotFound(path.display().to_string()))
364 }
365 }
366
367 pub async fn has_changed<P: AsRef<Path>>(&self, path: P) -> Result<bool, PluginLoadError> {
369 let path = path.as_ref().to_path_buf();
370
371 let libraries = self.libraries.read().await;
372 if let Some(lib) = libraries.get(&path) {
373 let current_hash = Self::calculate_hash(&path)?;
374 Ok(current_hash != lib.hash)
375 } else {
376 Ok(true) }
378 }
379
380 pub async fn get_library<P: AsRef<Path>>(&self, path: P) -> Option<Arc<PluginLibrary>> {
382 let libraries = self.libraries.read().await;
383 libraries.get(path.as_ref()).cloned()
384 }
385
386 pub async fn list_libraries(&self) -> Vec<PathBuf> {
388 let libraries = self.libraries.read().await;
389 libraries.keys().cloned().collect()
390 }
391
392 pub async fn create_plugin<P: AsRef<Path>>(
394 &self,
395 path: P,
396 ) -> Result<DynamicPlugin, PluginLoadError> {
397 let path = path.as_ref().to_path_buf();
398
399 let library = self.load_library(&path).await?;
400
401 let plugin = unsafe { library.create_instance()? };
402
403 Ok(DynamicPlugin::new(plugin, path))
404 }
405
406 pub async fn reload_library<P: AsRef<Path>>(
408 &self,
409 path: P,
410 ) -> Result<Arc<PluginLibrary>, PluginLoadError> {
411 let path = path.as_ref().to_path_buf();
412
413 let _ = self.unload_library(&path).await;
415
416 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
418
419 self.load_library(&path).await
421 }
422}
423
424impl Default for PluginLoader {
425 fn default() -> Self {
426 Self::new()
427 }
428}
429
/// Exports the C-ABI entry points that `PluginLoader` / `PluginLibrary` look up
/// in a plugin library.
///
/// - `$plugin_type`: the concrete plugin type.
/// - `$create_fn`: an expression that evaluates to a fresh `$plugin_type`.
///
/// Generated symbols: `_plugin_create`, `_plugin_destroy`,
/// `_plugin_api_version` and `_plugin_metadata`. `$create_fn` is evaluated once
/// per `_plugin_create` call and once per `_plugin_metadata` call; the latter
/// builds a throwaway instance only to read its metadata.
#[macro_export]
macro_rules! declare_plugin {
    ($plugin_type:ty, $create_fn:expr) => {
        // Trait-object pointers are fat (data + vtable) and have no C
        // equivalent. That is accepted here because host and plugin are both
        // Rust and only exchange these pointers through these exports, so the
        // lint is silenced instead of firing in every plugin crate.
        #[no_mangle]
        #[allow(improper_ctypes_definitions)]
        pub extern "C" fn _plugin_create() -> *mut dyn $crate::plugin::AgentPlugin {
            let plugin: Box<dyn $crate::plugin::AgentPlugin> = Box::new($create_fn);
            Box::into_raw(plugin)
        }

        #[no_mangle]
        #[allow(improper_ctypes_definitions)]
        pub extern "C" fn _plugin_destroy(plugin: *mut dyn $crate::plugin::AgentPlugin) {
            if !plugin.is_null() {
                // SAFETY: non-null pointers passed here are expected to come
                // from `_plugin_create` (`Box::into_raw`) and to be destroyed
                // at most once.
                drop(unsafe { Box::from_raw(plugin) });
            }
        }

        #[no_mangle]
        pub extern "C" fn _plugin_api_version() -> u32 {
            $crate::hot_reload::PluginLoader::CURRENT_API_VERSION
        }

        #[no_mangle]
        #[allow(improper_ctypes_definitions)]
        pub extern "C" fn _plugin_metadata() -> $crate::plugin::PluginMetadata {
            // Bring the trait into scope anonymously, so `.metadata()` resolves
            // even when the invoking module has not imported `AgentPlugin`.
            #[allow(unused_imports)]
            use $crate::plugin::AgentPlugin as _;

            let plugin: $plugin_type = $create_fn;
            plugin.metadata().clone()
        }
    };
}
461
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh loader targets the current API version with version checks on.
    #[tokio::test]
    async fn test_plugin_loader_new() {
        let fresh = PluginLoader::new();
        assert_eq!(fresh.api_version, PluginLoader::CURRENT_API_VERSION);
        assert!(!fresh.unsafe_mode);
    }

    /// Every added directory is recorded.
    #[tokio::test]
    async fn test_search_paths() {
        let mut loader = PluginLoader::new();
        for dir in ["/usr/lib/plugins", "/opt/plugins"] {
            loader.add_search_path(dir);
        }
        assert_eq!(loader.search_paths.len(), 2);
    }

    /// Hashing the same file twice gives the same, non-empty digest.
    #[test]
    fn test_calculate_hash() {
        let scratch = tempfile::tempdir().unwrap();
        let target = scratch.path().join("test.txt");
        std::fs::write(&target, b"test content").unwrap();

        let first = PluginLoader::calculate_hash(&target).unwrap();
        let second = PluginLoader::calculate_hash(&target).unwrap();

        assert_eq!(first, second);
        assert!(!first.is_empty());
    }
}