Skip to main content

mockforge_plugin_loader/
loader.rs

//! Plugin loader implementation.
//!
//! This module provides the main [`PluginLoader`], which handles:
//! - Plugin discovery and validation
//! - Secure plugin loading with sandboxing
//! - Plugin lifecycle management
//! - Resource monitoring and cleanup
8
9use super::*;
10use crate::registry::PluginRegistry;
11use crate::sandbox::PluginSandbox;
12use crate::validator::PluginValidator;
13use std::path::PathBuf;
14use std::sync::Arc;
15use tokio::sync::RwLock;
16use tokio::time::{timeout, Duration};
17
18// Import types from plugin core
19use mockforge_plugin_core::{PluginHealth, PluginId, PluginInstance, PluginManifest};
20
/// Main plugin loader.
///
/// Scans the configured plugin directories for subdirectories containing a
/// `plugin.yaml` manifest, validates them, instantiates each plugin's
/// WebAssembly module through the sandbox, and keeps the resulting instances
/// in a shared registry.
pub struct PluginLoader {
    /// Registry of loaded plugin instances. Shared through an `Arc` so callers
    /// can obtain a handle via [`PluginLoader::registry`].
    registry: Arc<RwLock<PluginRegistry>>,
    /// Validates manifests, requested capabilities, and WebAssembly files.
    validator: PluginValidator,
    /// Creates plugin instances from a load context.
    sandbox: PluginSandbox,
    /// Loader configuration (plugin directories, load timeout, ...).
    config: PluginLoaderConfig,
    /// Loading statistics, updated by `load_all_plugins` and read by
    /// `get_load_stats`.
    stats: RwLock<PluginLoadStats>,
}
34
// SAFETY: PluginLoader is declared Send + Sync by hand because PluginRegistry
// requires manual Send + Sync due to wasmtime types. The registry is only
// accessed through Arc<RwLock<...>>, which synchronizes access to it.
//
// NOTE(review): `validator` (PluginValidator) and `sandbox` (PluginSandbox) are
// stored directly, not behind Arc/RwLock wrappers. These impls therefore also
// assert that those two types are Send + Sync themselves — confirm.
//
// NOTE(review): Arc<T> and tokio's RwLock<T> are only Sync when T: Send + Sync,
// so wrapping the registry does not by itself make a non-thread-safe registry
// shareable. Soundness rests on PluginRegistry's own guarantees — confirm.
//
// NOTE(review): manual impls also bypass the compiler's auto-trait check for
// any field added in the future.
unsafe impl Send for PluginLoader {}
unsafe impl Sync for PluginLoader {}
40
41impl PluginLoader {
42    /// Create a new plugin loader
43    pub fn new(config: PluginLoaderConfig) -> Self {
44        Self {
45            registry: Arc::new(RwLock::new(PluginRegistry::new())),
46            validator: PluginValidator::new(config.clone()),
47            sandbox: PluginSandbox::new(config.clone()),
48            config,
49            stats: RwLock::new(PluginLoadStats::default()),
50        }
51    }
52
53    /// Load all plugins from configured directories
54    pub async fn load_all_plugins(&self) -> LoaderResult<PluginLoadStats> {
55        let mut stats = self.stats.write().await;
56        stats.start_loading();
57
58        // Discover plugins from all configured directories
59        let mut all_discoveries = Vec::new();
60        for dir in &self.config.plugin_dirs {
61            let discoveries = self.discover_plugins_in_directory(dir).await?;
62            all_discoveries.extend(discoveries);
63        }
64
65        stats.discovered = all_discoveries.len();
66
67        // Load valid plugins
68        for discovery in all_discoveries {
69            if discovery.is_valid {
70                match self.load_plugin_from_discovery(&discovery).await {
71                    Ok(_) => stats.record_success(),
72                    Err(e) => {
73                        tracing::warn!("Failed to load plugin {}: {}", discovery.plugin_id, e);
74                        stats.record_failure();
75                    }
76                }
77            } else {
78                tracing::debug!(
79                    "Skipping invalid plugin {}: {}",
80                    discovery.plugin_id,
81                    discovery.first_error().unwrap_or("unknown error")
82                );
83                stats.record_skipped();
84            }
85        }
86
87        stats.finish_loading();
88        Ok(stats.clone())
89    }
90
91    /// Load a specific plugin by ID
92    pub async fn load_plugin(&self, plugin_id: &PluginId) -> LoaderResult<()> {
93        // Find plugin in discovery paths
94        let discovery = self
95            .discover_plugin_by_id(plugin_id)
96            .await?
97            .ok_or_else(|| PluginLoaderError::not_found(plugin_id.clone()))?;
98
99        if !discovery.is_valid {
100            return Err(PluginLoaderError::validation(
101                discovery.first_error().unwrap_or("Plugin validation failed").to_string(),
102            ));
103        }
104
105        self.load_plugin_from_discovery(&discovery).await
106    }
107
108    /// Unload a plugin
109    pub async fn unload_plugin(&self, plugin_id: &PluginId) -> LoaderResult<()> {
110        let mut registry = self.registry.write().await;
111        registry.remove_plugin(plugin_id)?;
112
113        tracing::info!("Unloaded plugin: {}", plugin_id);
114        Ok(())
115    }
116
117    /// Get loaded plugin by ID
118    pub async fn get_plugin(&self, plugin_id: &PluginId) -> Option<PluginInstance> {
119        let registry = self.registry.read().await;
120        registry.get_plugin(plugin_id).cloned()
121    }
122
123    /// List all loaded plugins
124    pub async fn list_plugins(&self) -> Vec<PluginId> {
125        let registry = self.registry.read().await;
126        registry.list_plugins()
127    }
128
129    /// Get plugin health status
130    pub async fn get_plugin_health(&self, plugin_id: &PluginId) -> LoaderResult<PluginHealth> {
131        let registry = self.registry.read().await;
132        registry.get_plugin_health(plugin_id)
133    }
134
135    /// Get loading statistics
136    pub async fn get_load_stats(&self) -> PluginLoadStats {
137        self.stats.read().await.clone()
138    }
139
140    /// Validate plugin without loading
141    pub async fn validate_plugin(&self, plugin_path: &Path) -> LoaderResult<PluginManifest> {
142        self.validator.validate_plugin_file(plugin_path).await
143    }
144
145    /// Discover plugins in a directory
146    async fn discover_plugins_in_directory(
147        &self,
148        dir_path: &str,
149    ) -> LoaderResult<Vec<PluginDiscovery>> {
150        let expanded_path = shellexpand::tilde(dir_path);
151        let dir_path = PathBuf::from(expanded_path.as_ref());
152
153        if !dir_path.exists() {
154            tracing::debug!("Plugin directory does not exist: {}", dir_path.display());
155            return Ok(Vec::new());
156        }
157
158        if !dir_path.is_dir() {
159            return Err(PluginLoaderError::fs(format!(
160                "Plugin path is not a directory: {}",
161                dir_path.display()
162            )));
163        }
164
165        let mut discoveries = Vec::new();
166
167        // Read directory entries
168        let mut entries = match tokio::fs::read_dir(&dir_path).await {
169            Ok(entries) => entries,
170            Err(e) => {
171                return Err(PluginLoaderError::fs(format!(
172                    "Failed to read plugin directory {}: {}",
173                    dir_path.display(),
174                    e
175                )));
176            }
177        };
178
179        while let Ok(Some(entry)) = entries.next_entry().await {
180            let path = entry.path();
181
182            // Skip non-directories
183            if !path.is_dir() {
184                continue;
185            }
186
187            // Look for plugin.yaml in the directory
188            let manifest_path = path.join("plugin.yaml");
189            if !manifest_path.exists() {
190                continue;
191            }
192
193            // Try to discover plugin
194            match self.discover_single_plugin(&manifest_path).await {
195                Ok(discovery) => discoveries.push(discovery),
196                Err(e) => {
197                    tracing::warn!("Failed to discover plugin at {}: {}", path.display(), e);
198                }
199            }
200        }
201
202        Ok(discoveries)
203    }
204
205    /// Discover a single plugin from manifest file
206    async fn discover_single_plugin(&self, manifest_path: &Path) -> LoaderResult<PluginDiscovery> {
207        // Load and validate manifest
208        let manifest = match PluginManifest::from_file(manifest_path) {
209            Ok(manifest) => manifest,
210            Err(e) => {
211                let plugin_id = PluginId::new("unknown".to_string());
212                let errors = vec![format!("Failed to load manifest: {}", e)];
213                return Ok(PluginDiscovery::failure(
214                    plugin_id,
215                    manifest_path.display().to_string(),
216                    errors,
217                ));
218            }
219        };
220
221        let plugin_id = manifest.id().clone();
222
223        // Validate manifest
224        let validation_result = self.validator.validate_manifest(&manifest).await;
225
226        match validation_result {
227            Ok(_) => {
228                let discovery = PluginDiscovery::success(
229                    plugin_id,
230                    manifest,
231                    manifest_path.parent().unwrap().display().to_string(),
232                );
233                Ok(discovery)
234            }
235            Err(errors) => {
236                let errors_str: Vec<String> = vec![errors.to_string()];
237                Ok(PluginDiscovery::failure(
238                    plugin_id,
239                    manifest_path.display().to_string(),
240                    errors_str,
241                ))
242            }
243        }
244    }
245
246    /// Discover plugin by ID
247    async fn discover_plugin_by_id(
248        &self,
249        plugin_id: &PluginId,
250    ) -> LoaderResult<Option<PluginDiscovery>> {
251        for dir in &self.config.plugin_dirs {
252            let discoveries = self.discover_plugins_in_directory(dir).await?;
253            if let Some(discovery) = discoveries.into_iter().find(|d| &d.plugin_id == plugin_id) {
254                return Ok(Some(discovery));
255            }
256        }
257        Ok(None)
258    }
259
260    /// Load plugin from discovery result
261    async fn load_plugin_from_discovery(&self, discovery: &PluginDiscovery) -> LoaderResult<()> {
262        let plugin_id = &discovery.plugin_id;
263
264        // Check if already loaded
265        {
266            let registry = self.registry.read().await;
267            if registry.has_plugin(plugin_id) {
268                return Err(PluginLoaderError::already_loaded(plugin_id.clone()));
269            }
270        }
271
272        // Validate capabilities
273        self.validator.validate_capabilities(&discovery.manifest.capabilities)?;
274
275        // Find WASM file
276        let plugin_path = self.find_plugin_wasm_file(&discovery.path).await?.ok_or_else(|| {
277            PluginLoaderError::load(format!("No WebAssembly file found for plugin {}", plugin_id))
278        })?;
279
280        // Validate WASM file
281        self.validator.validate_wasm_file(&plugin_path).await?;
282
283        // Create load context
284        let load_context = PluginLoadContext::new(
285            plugin_id.clone(),
286            discovery.manifest.clone(),
287            plugin_path.display().to_string(),
288            self.config.clone(),
289        );
290
291        // Load plugin with timeout
292        let load_timeout = Duration::from_secs(self.config.load_timeout_secs);
293        let plugin_instance = timeout(load_timeout, self.load_plugin_instance(&load_context))
294            .await
295            .map_err(|_| {
296                PluginLoaderError::load(format!(
297                    "Plugin loading timed out after {} seconds",
298                    self.config.load_timeout_secs
299                ))
300            })??;
301
302        // Register plugin
303        let mut registry = self.registry.write().await;
304        registry.add_plugin(plugin_instance)?;
305
306        tracing::info!("Successfully loaded plugin: {}", plugin_id);
307        Ok(())
308    }
309
310    /// Load plugin instance
311    async fn load_plugin_instance(
312        &self,
313        context: &PluginLoadContext,
314    ) -> LoaderResult<PluginInstance> {
315        // Create plugin instance through sandbox
316        self.sandbox.create_plugin_instance(context).await
317    }
318
319    /// Find plugin WASM file in plugin directory
320    async fn find_plugin_wasm_file(&self, plugin_dir: &str) -> LoaderResult<Option<PathBuf>> {
321        let plugin_path = PathBuf::from(plugin_dir);
322
323        // Look for .wasm files in the plugin directory
324        let mut entries = match tokio::fs::read_dir(&plugin_path).await {
325            Ok(entries) => entries,
326            Err(e) => {
327                return Err(PluginLoaderError::fs(format!(
328                    "Failed to read plugin directory {}: {}",
329                    plugin_path.display(),
330                    e
331                )));
332            }
333        };
334
335        while let Ok(Some(entry)) = entries.next_entry().await {
336            let path = entry.path();
337            if let Some(extension) = path.extension() {
338                if extension == "wasm" {
339                    return Ok(Some(path));
340                }
341            }
342        }
343
344        Ok(None)
345    }
346
347    /// Reload all plugins
348    pub async fn reload_all_plugins(&self) -> LoaderResult<PluginLoadStats> {
349        // Get currently loaded plugins
350        let loaded_plugins = self.list_plugins().await;
351
352        // Unload all plugins
353        for plugin_id in &loaded_plugins {
354            if let Err(e) = self.unload_plugin(plugin_id).await {
355                tracing::warn!("Failed to unload plugin {} during reload: {}", plugin_id, e);
356            }
357        }
358
359        // Load all plugins again
360        self.load_all_plugins().await
361    }
362
363    /// Reload specific plugin
364    pub async fn reload_plugin(&self, plugin_id: &PluginId) -> LoaderResult<()> {
365        // Unload plugin
366        self.unload_plugin(plugin_id).await?;
367
368        // Load plugin again
369        self.load_plugin(plugin_id).await
370    }
371
372    /// Get registry reference (for advanced operations)
373    pub fn registry(&self) -> Arc<RwLock<PluginRegistry>> {
374        Arc::clone(&self.registry)
375    }
376
377    /// Get validator reference (for advanced operations)
378    pub fn validator(&self) -> &PluginValidator {
379        &self.validator
380    }
381
382    /// Get sandbox reference (for advanced operations)
383    pub fn sandbox(&self) -> &PluginSandbox {
384        &self.sandbox
385    }
386}
387
388impl Default for PluginLoader {
389    fn default() -> Self {
390        Self::new(PluginLoaderConfig::default())
391    }
392}
393
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_plugin_loader_creation() {
        let loader = PluginLoader::new(PluginLoaderConfig::default());

        // A freshly created loader has not discovered or loaded anything yet.
        let stats = loader.get_load_stats().await;
        assert_eq!(stats.discovered, 0);
        assert_eq!(stats.loaded, 0);
    }

    #[tokio::test]
    async fn test_load_stats() {
        let mut stats = PluginLoadStats::default();

        stats.start_loading();
        assert!(stats.start_time.is_some());

        stats.record_success();
        stats.record_failure();
        stats.record_skipped();

        stats.finish_loading();
        assert!(stats.end_time.is_some());

        assert_eq!(stats.loaded, 1);
        assert_eq!(stats.failed, 1);
        assert_eq!(stats.skipped, 1);
        // Every recorded outcome counts as one discovered plugin.
        assert_eq!(stats.discovered, 3);

        // Compare with a tolerance: `1/3 * 100` and `100/3` round differently
        // in the last bit, so an exact float comparison is brittle.
        let rate = stats.success_rate();
        assert!((rate - 100.0 / 3.0).abs() < 1e-4, "unexpected success rate: {}", rate);
    }

    #[tokio::test]
    async fn test_plugin_discovery_success() {
        let plugin_id = PluginId::new("test-plugin");
        let manifest = PluginManifest::new(PluginInfo::new(
            plugin_id.clone(),
            PluginVersion::new(1, 0, 0),
            "Test Plugin",
            "A test plugin",
            PluginAuthor::new("Test Author"),
        ));

        let discovery = PluginDiscovery::success(plugin_id, manifest, "/tmp/test".to_string());
        assert!(discovery.is_success());
        assert!(discovery.errors.is_empty());
    }

    #[tokio::test]
    async fn test_plugin_discovery_failure() {
        let plugin_id = PluginId::new("test-plugin");
        let errors = vec!["Validation failed".to_string()];

        let discovery =
            PluginDiscovery::failure(plugin_id, "/tmp/test".to_string(), errors.clone());
        assert!(!discovery.is_success());
        assert_eq!(discovery.errors, errors);
        assert_eq!(discovery.first_error(), Some("Validation failed"));
    }
}