1use super::*;
10use crate::registry::PluginRegistry;
11use crate::sandbox::PluginSandbox;
12use crate::validator::PluginValidator;
13use std::path::PathBuf;
14use std::sync::Arc;
15use tokio::sync::RwLock;
16use tokio::time::{timeout, Duration};
17
18use mockforge_plugin_core::{PluginHealth, PluginId, PluginInstance, PluginManifest};
20
/// Discovers plugin directories, validates their manifests and WebAssembly
/// modules, and loads them into a shared [`PluginRegistry`].
pub struct PluginLoader {
    /// Registry of loaded plugin instances; a shared handle is exposed via
    /// [`PluginLoader::registry`].
    registry: Arc<RwLock<PluginRegistry>>,
    /// Validates manifests, declared capabilities, and WebAssembly files.
    validator: PluginValidator,
    /// Creates plugin instances from a `PluginLoadContext`.
    sandbox: PluginSandbox,
    /// Loader configuration (plugin directories, load timeout, ...).
    config: PluginLoaderConfig,
    /// Load statistics, updated by `load_all_plugins` and read by
    /// `get_load_stats`.
    stats: RwLock<PluginLoadStats>,
}
34
// SAFETY: NOTE(review): no justification is given for these impls. The fields
// are `Arc<RwLock<PluginRegistry>>`, `PluginValidator`, `PluginSandbox`,
// `PluginLoaderConfig` and `RwLock<PluginLoadStats>`. If every one of those is
// already `Send + Sync`, the compiler derives both traits automatically and
// these impls should be deleted so the guarantee stays compiler-checked. If
// any of them is not (e.g. a non-thread-safe engine handle held by the
// sandbox — TODO confirm), sharing a `PluginLoader` across threads may be
// unsound and these impls hide that.
unsafe impl Send for PluginLoader {}
// SAFETY: see the review note on the `Send` impl directly above.
unsafe impl Sync for PluginLoader {}
38
39impl PluginLoader {
40 pub fn new(config: PluginLoaderConfig) -> Self {
42 Self {
43 registry: Arc::new(RwLock::new(PluginRegistry::new())),
44 validator: PluginValidator::new(config.clone()),
45 sandbox: PluginSandbox::new(config.clone()),
46 config,
47 stats: RwLock::new(PluginLoadStats::default()),
48 }
49 }
50
51 pub async fn load_all_plugins(&self) -> LoaderResult<PluginLoadStats> {
53 let mut stats = self.stats.write().await;
54 stats.start_loading();
55
56 let mut all_discoveries = Vec::new();
58 for dir in &self.config.plugin_dirs {
59 let discoveries = self.discover_plugins_in_directory(dir).await?;
60 all_discoveries.extend(discoveries);
61 }
62
63 stats.discovered = all_discoveries.len();
64
65 for discovery in all_discoveries {
67 if discovery.is_valid {
68 match self.load_plugin_from_discovery(&discovery).await {
69 Ok(_) => stats.record_success(),
70 Err(e) => {
71 tracing::warn!("Failed to load plugin {}: {}", discovery.plugin_id, e);
72 stats.record_failure();
73 }
74 }
75 } else {
76 tracing::debug!(
77 "Skipping invalid plugin {}: {}",
78 discovery.plugin_id,
79 discovery.first_error().unwrap_or("unknown error")
80 );
81 stats.record_skipped();
82 }
83 }
84
85 stats.finish_loading();
86 Ok(stats.clone())
87 }
88
89 pub async fn load_plugin(&self, plugin_id: &PluginId) -> LoaderResult<()> {
91 let discovery = self
93 .discover_plugin_by_id(plugin_id)
94 .await?
95 .ok_or_else(|| PluginLoaderError::not_found(plugin_id.clone()))?;
96
97 if !discovery.is_valid {
98 return Err(PluginLoaderError::validation(
99 discovery.first_error().unwrap_or("Plugin validation failed").to_string(),
100 ));
101 }
102
103 self.load_plugin_from_discovery(&discovery).await
104 }
105
106 pub async fn unload_plugin(&self, plugin_id: &PluginId) -> LoaderResult<()> {
108 let mut registry = self.registry.write().await;
109 registry.remove_plugin(plugin_id)?;
110
111 tracing::info!("Unloaded plugin: {}", plugin_id);
112 Ok(())
113 }
114
115 pub async fn get_plugin(&self, plugin_id: &PluginId) -> Option<PluginInstance> {
117 let registry = self.registry.read().await;
118 registry.get_plugin(plugin_id).cloned()
119 }
120
121 pub async fn list_plugins(&self) -> Vec<PluginId> {
123 let registry = self.registry.read().await;
124 registry.list_plugins()
125 }
126
127 pub async fn get_plugin_health(&self, plugin_id: &PluginId) -> LoaderResult<PluginHealth> {
129 let registry = self.registry.read().await;
130 registry.get_plugin_health(plugin_id)
131 }
132
133 pub async fn get_load_stats(&self) -> PluginLoadStats {
135 self.stats.read().await.clone()
136 }
137
138 pub async fn validate_plugin(&self, plugin_path: &Path) -> LoaderResult<PluginManifest> {
140 self.validator.validate_plugin_file(plugin_path).await
141 }
142
143 async fn discover_plugins_in_directory(
145 &self,
146 dir_path: &str,
147 ) -> LoaderResult<Vec<PluginDiscovery>> {
148 let expanded_path = shellexpand::tilde(dir_path);
149 let dir_path = PathBuf::from(expanded_path.as_ref());
150
151 if !dir_path.exists() {
152 tracing::debug!("Plugin directory does not exist: {}", dir_path.display());
153 return Ok(Vec::new());
154 }
155
156 if !dir_path.is_dir() {
157 return Err(PluginLoaderError::fs(format!(
158 "Plugin path is not a directory: {}",
159 dir_path.display()
160 )));
161 }
162
163 let mut discoveries = Vec::new();
164
165 let mut entries = match tokio::fs::read_dir(&dir_path).await {
167 Ok(entries) => entries,
168 Err(e) => {
169 return Err(PluginLoaderError::fs(format!(
170 "Failed to read plugin directory {}: {}",
171 dir_path.display(),
172 e
173 )));
174 }
175 };
176
177 while let Ok(Some(entry)) = entries.next_entry().await {
178 let path = entry.path();
179
180 if !path.is_dir() {
182 continue;
183 }
184
185 let manifest_path = path.join("plugin.yaml");
187 if !manifest_path.exists() {
188 continue;
189 }
190
191 match self.discover_single_plugin(&manifest_path).await {
193 Ok(discovery) => discoveries.push(discovery),
194 Err(e) => {
195 tracing::warn!("Failed to discover plugin at {}: {}", path.display(), e);
196 }
197 }
198 }
199
200 Ok(discoveries)
201 }
202
203 async fn discover_single_plugin(&self, manifest_path: &Path) -> LoaderResult<PluginDiscovery> {
205 let manifest = match PluginManifest::from_file(manifest_path) {
207 Ok(manifest) => manifest,
208 Err(e) => {
209 let plugin_id = PluginId::new("unknown".to_string());
210 let errors = vec![format!("Failed to load manifest: {}", e)];
211 return Ok(PluginDiscovery::failure(
212 plugin_id,
213 manifest_path.display().to_string(),
214 errors,
215 ));
216 }
217 };
218
219 let plugin_id = manifest.id().clone();
220
221 let validation_result = self.validator.validate_manifest(&manifest).await;
223
224 match validation_result {
225 Ok(_) => {
226 let discovery = PluginDiscovery::success(
227 plugin_id,
228 manifest,
229 manifest_path.parent().unwrap().display().to_string(),
230 );
231 Ok(discovery)
232 }
233 Err(errors) => {
234 let errors_str: Vec<String> = vec![errors.to_string()];
235 Ok(PluginDiscovery::failure(
236 plugin_id,
237 manifest_path.display().to_string(),
238 errors_str,
239 ))
240 }
241 }
242 }
243
244 async fn discover_plugin_by_id(
246 &self,
247 plugin_id: &PluginId,
248 ) -> LoaderResult<Option<PluginDiscovery>> {
249 for dir in &self.config.plugin_dirs {
250 let discoveries = self.discover_plugins_in_directory(dir).await?;
251 if let Some(discovery) = discoveries.into_iter().find(|d| &d.plugin_id == plugin_id) {
252 return Ok(Some(discovery));
253 }
254 }
255 Ok(None)
256 }
257
258 async fn load_plugin_from_discovery(&self, discovery: &PluginDiscovery) -> LoaderResult<()> {
260 let plugin_id = &discovery.plugin_id;
261
262 {
264 let registry = self.registry.read().await;
265 if registry.has_plugin(plugin_id) {
266 return Err(PluginLoaderError::already_loaded(plugin_id.clone()));
267 }
268 }
269
270 self.validator.validate_capabilities(&discovery.manifest.capabilities)?;
272
273 let plugin_path = self.find_plugin_wasm_file(&discovery.path).await?.ok_or_else(|| {
275 PluginLoaderError::load(format!("No WebAssembly file found for plugin {}", plugin_id))
276 })?;
277
278 self.validator.validate_wasm_file(&plugin_path).await?;
280
281 let load_context = PluginLoadContext::new(
283 plugin_id.clone(),
284 discovery.manifest.clone(),
285 plugin_path.display().to_string(),
286 self.config.clone(),
287 );
288
289 let load_timeout = Duration::from_secs(self.config.load_timeout_secs);
291 let plugin_instance = timeout(load_timeout, self.load_plugin_instance(&load_context))
292 .await
293 .map_err(|_| {
294 PluginLoaderError::load(format!(
295 "Plugin loading timed out after {} seconds",
296 self.config.load_timeout_secs
297 ))
298 })??;
299
300 let mut registry = self.registry.write().await;
302 registry.add_plugin(plugin_instance)?;
303
304 tracing::info!("Successfully loaded plugin: {}", plugin_id);
305 Ok(())
306 }
307
308 async fn load_plugin_instance(
310 &self,
311 context: &PluginLoadContext,
312 ) -> LoaderResult<PluginInstance> {
313 self.sandbox.create_plugin_instance(context).await
315 }
316
317 async fn find_plugin_wasm_file(&self, plugin_dir: &str) -> LoaderResult<Option<PathBuf>> {
319 let plugin_path = PathBuf::from(plugin_dir);
320
321 let mut entries = match tokio::fs::read_dir(&plugin_path).await {
323 Ok(entries) => entries,
324 Err(e) => {
325 return Err(PluginLoaderError::fs(format!(
326 "Failed to read plugin directory {}: {}",
327 plugin_path.display(),
328 e
329 )));
330 }
331 };
332
333 while let Ok(Some(entry)) = entries.next_entry().await {
334 let path = entry.path();
335 if let Some(extension) = path.extension() {
336 if extension == "wasm" {
337 return Ok(Some(path));
338 }
339 }
340 }
341
342 Ok(None)
343 }
344
345 pub async fn reload_all_plugins(&self) -> LoaderResult<PluginLoadStats> {
347 let loaded_plugins = self.list_plugins().await;
349
350 for plugin_id in &loaded_plugins {
352 if let Err(e) = self.unload_plugin(plugin_id).await {
353 tracing::warn!("Failed to unload plugin {} during reload: {}", plugin_id, e);
354 }
355 }
356
357 self.load_all_plugins().await
359 }
360
361 pub async fn reload_plugin(&self, plugin_id: &PluginId) -> LoaderResult<()> {
363 self.unload_plugin(plugin_id).await?;
365
366 self.load_plugin(plugin_id).await
368 }
369
370 pub fn registry(&self) -> Arc<RwLock<PluginRegistry>> {
372 Arc::clone(&self.registry)
373 }
374
375 pub fn validator(&self) -> &PluginValidator {
377 &self.validator
378 }
379
380 pub fn sandbox(&self) -> &PluginSandbox {
382 &self.sandbox
383 }
384}
385
386impl Default for PluginLoader {
387 fn default() -> Self {
388 Self::new(PluginLoaderConfig::default())
389 }
390}
391
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_plugin_loader_creation() {
        let config = PluginLoaderConfig::default();
        let loader = PluginLoader::new(config);

        let stats = loader.get_load_stats().await;
        assert_eq!(stats.discovered, 0);
        assert_eq!(stats.loaded, 0);
    }

    #[tokio::test]
    async fn test_new_loader_has_no_plugins() {
        let loader = PluginLoader::default();

        assert!(loader.list_plugins().await.is_empty());
        assert!(loader.get_plugin(&PluginId::new("missing")).await.is_none());
    }

    #[test]
    fn test_load_stats() {
        let mut stats = PluginLoadStats::default();

        stats.start_loading();
        assert!(stats.start_time.is_some());

        stats.record_success();
        stats.record_failure();
        stats.record_skipped();

        stats.finish_loading();
        assert!(stats.end_time.is_some());

        assert_eq!(stats.loaded, 1);
        assert_eq!(stats.failed, 1);
        assert_eq!(stats.skipped, 1);
        assert_eq!(stats.discovered, 3);

        // Compare with a tolerance: the exact f64 depends on the order of the
        // division and multiplication inside `success_rate`.
        let expected = 100.0 / 3.0;
        assert!(
            (stats.success_rate() - expected).abs() < 1e-9,
            "unexpected success rate: {}",
            stats.success_rate()
        );
    }

    #[test]
    fn test_plugin_discovery_success() {
        let plugin_id = PluginId::new("test-plugin");
        let manifest = PluginManifest::new(PluginInfo::new(
            plugin_id.clone(),
            PluginVersion::new(1, 0, 0),
            "Test Plugin",
            "A test plugin",
            PluginAuthor::new("Test Author"),
        ));

        let discovery = PluginDiscovery::success(plugin_id, manifest, "/tmp/test".to_string());
        assert!(discovery.is_success());
        assert!(discovery.errors.is_empty());
    }

    #[test]
    fn test_plugin_discovery_failure() {
        let plugin_id = PluginId::new("test-plugin");
        let errors = vec!["Validation failed".to_string()];

        let discovery =
            PluginDiscovery::failure(plugin_id, "/tmp/test".to_string(), errors.clone());
        assert!(!discovery.is_success());
        assert_eq!(discovery.errors, errors);
        assert_eq!(discovery.first_error(), Some("Validation failed"));
    }
}