1use super::*;
10use crate::registry::PluginRegistry;
11use crate::sandbox::PluginSandbox;
12use crate::validator::PluginValidator;
13use std::path::PathBuf;
14use std::sync::Arc;
15use tokio::sync::RwLock;
16use tokio::time::{timeout, Duration};
17
18use mockforge_plugin_core::{PluginHealth, PluginId, PluginInstance, PluginManifest};
20
/// Discovers plugin directories, validates their manifests and WebAssembly modules, and
/// loads the resulting instances into a shared [`PluginRegistry`].
pub struct PluginLoader {
    /// Loaded plugin instances. Wrapped in `Arc` so `registry()` hands out the same registry.
    registry: Arc<RwLock<PluginRegistry>>,
    /// Checks plugin files, manifests, capabilities, and WASM modules before loading.
    validator: PluginValidator,
    /// Creates plugin instances from a `PluginLoadContext`.
    sandbox: PluginSandbox,
    /// Loader settings; this file reads `plugin_dirs` and `load_timeout_secs` from it.
    config: PluginLoaderConfig,
    /// Load statistics, written by `load_all_plugins` and read by `get_load_stats`.
    stats: RwLock<PluginLoadStats>,
}
34
// SAFETY: NOTE(review): these impls are asserted by hand instead of being derived by the
// compiler, and no justification is visible in this file. If every field of `PluginLoader`
// is already `Send`/`Sync`, the impls are redundant and should be deleted so that a future
// non-thread-safe field is rejected at compile time. If some field (e.g. something inside
// `PluginSandbox` or `PluginRegistry`) is not, the invariant that makes sharing it across
// threads sound must be written down here. TODO: confirm which field required these impls.
unsafe impl Send for PluginLoader {}
unsafe impl Sync for PluginLoader {}
40
41impl PluginLoader {
42 pub fn new(config: PluginLoaderConfig) -> Self {
44 Self {
45 registry: Arc::new(RwLock::new(PluginRegistry::new())),
46 validator: PluginValidator::new(config.clone()),
47 sandbox: PluginSandbox::new(config.clone()),
48 config,
49 stats: RwLock::new(PluginLoadStats::default()),
50 }
51 }
52
53 pub async fn load_all_plugins(&self) -> LoaderResult<PluginLoadStats> {
55 let mut stats = self.stats.write().await;
56 stats.start_loading();
57
58 let mut all_discoveries = Vec::new();
60 for dir in &self.config.plugin_dirs {
61 let discoveries = self.discover_plugins_in_directory(dir).await?;
62 all_discoveries.extend(discoveries);
63 }
64
65 stats.discovered = all_discoveries.len();
66
67 for discovery in all_discoveries {
69 if discovery.is_valid {
70 match self.load_plugin_from_discovery(&discovery).await {
71 Ok(_) => stats.record_success(),
72 Err(e) => {
73 tracing::warn!("Failed to load plugin {}: {}", discovery.plugin_id, e);
74 stats.record_failure();
75 }
76 }
77 } else {
78 tracing::debug!(
79 "Skipping invalid plugin {}: {}",
80 discovery.plugin_id,
81 discovery.first_error().unwrap_or("unknown error")
82 );
83 stats.record_skipped();
84 }
85 }
86
87 stats.finish_loading();
88 Ok(stats.clone())
89 }
90
91 pub async fn load_plugin(&self, plugin_id: &PluginId) -> LoaderResult<()> {
93 let discovery = self
95 .discover_plugin_by_id(plugin_id)
96 .await?
97 .ok_or_else(|| PluginLoaderError::not_found(plugin_id.clone()))?;
98
99 if !discovery.is_valid {
100 return Err(PluginLoaderError::validation(
101 discovery.first_error().unwrap_or("Plugin validation failed").to_string(),
102 ));
103 }
104
105 self.load_plugin_from_discovery(&discovery).await
106 }
107
108 pub async fn unload_plugin(&self, plugin_id: &PluginId) -> LoaderResult<()> {
110 let mut registry = self.registry.write().await;
111 registry.remove_plugin(plugin_id)?;
112
113 tracing::info!("Unloaded plugin: {}", plugin_id);
114 Ok(())
115 }
116
117 pub async fn get_plugin(&self, plugin_id: &PluginId) -> Option<PluginInstance> {
119 let registry = self.registry.read().await;
120 registry.get_plugin(plugin_id).cloned()
121 }
122
123 pub async fn list_plugins(&self) -> Vec<PluginId> {
125 let registry = self.registry.read().await;
126 registry.list_plugins()
127 }
128
129 pub async fn get_plugin_health(&self, plugin_id: &PluginId) -> LoaderResult<PluginHealth> {
131 let registry = self.registry.read().await;
132 registry.get_plugin_health(plugin_id)
133 }
134
135 pub async fn get_load_stats(&self) -> PluginLoadStats {
137 self.stats.read().await.clone()
138 }
139
140 pub async fn validate_plugin(&self, plugin_path: &Path) -> LoaderResult<PluginManifest> {
142 self.validator.validate_plugin_file(plugin_path).await
143 }
144
145 async fn discover_plugins_in_directory(
147 &self,
148 dir_path: &str,
149 ) -> LoaderResult<Vec<PluginDiscovery>> {
150 let expanded_path = shellexpand::tilde(dir_path);
151 let dir_path = PathBuf::from(expanded_path.as_ref());
152
153 if !dir_path.exists() {
154 tracing::debug!("Plugin directory does not exist: {}", dir_path.display());
155 return Ok(Vec::new());
156 }
157
158 if !dir_path.is_dir() {
159 return Err(PluginLoaderError::fs(format!(
160 "Plugin path is not a directory: {}",
161 dir_path.display()
162 )));
163 }
164
165 let mut discoveries = Vec::new();
166
167 let mut entries = match tokio::fs::read_dir(&dir_path).await {
169 Ok(entries) => entries,
170 Err(e) => {
171 return Err(PluginLoaderError::fs(format!(
172 "Failed to read plugin directory {}: {}",
173 dir_path.display(),
174 e
175 )));
176 }
177 };
178
179 while let Ok(Some(entry)) = entries.next_entry().await {
180 let path = entry.path();
181
182 if !path.is_dir() {
184 continue;
185 }
186
187 let manifest_path = path.join("plugin.yaml");
189 if !manifest_path.exists() {
190 continue;
191 }
192
193 match self.discover_single_plugin(&manifest_path).await {
195 Ok(discovery) => discoveries.push(discovery),
196 Err(e) => {
197 tracing::warn!("Failed to discover plugin at {}: {}", path.display(), e);
198 }
199 }
200 }
201
202 Ok(discoveries)
203 }
204
205 async fn discover_single_plugin(&self, manifest_path: &Path) -> LoaderResult<PluginDiscovery> {
207 let manifest = match PluginManifest::from_file(manifest_path) {
209 Ok(manifest) => manifest,
210 Err(e) => {
211 let plugin_id = PluginId::new("unknown".to_string());
212 let errors = vec![format!("Failed to load manifest: {}", e)];
213 return Ok(PluginDiscovery::failure(
214 plugin_id,
215 manifest_path.display().to_string(),
216 errors,
217 ));
218 }
219 };
220
221 let plugin_id = manifest.id().clone();
222
223 let validation_result = self.validator.validate_manifest(&manifest).await;
225
226 match validation_result {
227 Ok(_) => {
228 let discovery = PluginDiscovery::success(
229 plugin_id,
230 manifest,
231 manifest_path.parent().unwrap().display().to_string(),
232 );
233 Ok(discovery)
234 }
235 Err(errors) => {
236 let errors_str: Vec<String> = vec![errors.to_string()];
237 Ok(PluginDiscovery::failure(
238 plugin_id,
239 manifest_path.display().to_string(),
240 errors_str,
241 ))
242 }
243 }
244 }
245
246 async fn discover_plugin_by_id(
248 &self,
249 plugin_id: &PluginId,
250 ) -> LoaderResult<Option<PluginDiscovery>> {
251 for dir in &self.config.plugin_dirs {
252 let discoveries = self.discover_plugins_in_directory(dir).await?;
253 if let Some(discovery) = discoveries.into_iter().find(|d| &d.plugin_id == plugin_id) {
254 return Ok(Some(discovery));
255 }
256 }
257 Ok(None)
258 }
259
260 async fn load_plugin_from_discovery(&self, discovery: &PluginDiscovery) -> LoaderResult<()> {
262 let plugin_id = &discovery.plugin_id;
263
264 {
266 let registry = self.registry.read().await;
267 if registry.has_plugin(plugin_id) {
268 return Err(PluginLoaderError::already_loaded(plugin_id.clone()));
269 }
270 }
271
272 self.validator.validate_capabilities(&discovery.manifest.capabilities)?;
274
275 let plugin_path = self.find_plugin_wasm_file(&discovery.path).await?.ok_or_else(|| {
277 PluginLoaderError::load(format!("No WebAssembly file found for plugin {}", plugin_id))
278 })?;
279
280 self.validator.validate_wasm_file(&plugin_path).await?;
282
283 let load_context = PluginLoadContext::new(
285 plugin_id.clone(),
286 discovery.manifest.clone(),
287 plugin_path.display().to_string(),
288 self.config.clone(),
289 );
290
291 let load_timeout = Duration::from_secs(self.config.load_timeout_secs);
293 let plugin_instance = timeout(load_timeout, self.load_plugin_instance(&load_context))
294 .await
295 .map_err(|_| {
296 PluginLoaderError::load(format!(
297 "Plugin loading timed out after {} seconds",
298 self.config.load_timeout_secs
299 ))
300 })??;
301
302 let mut registry = self.registry.write().await;
304 registry.add_plugin(plugin_instance)?;
305
306 tracing::info!("Successfully loaded plugin: {}", plugin_id);
307 Ok(())
308 }
309
310 async fn load_plugin_instance(
312 &self,
313 context: &PluginLoadContext,
314 ) -> LoaderResult<PluginInstance> {
315 self.sandbox.create_plugin_instance(context).await
317 }
318
319 async fn find_plugin_wasm_file(&self, plugin_dir: &str) -> LoaderResult<Option<PathBuf>> {
321 let plugin_path = PathBuf::from(plugin_dir);
322
323 let mut entries = match tokio::fs::read_dir(&plugin_path).await {
325 Ok(entries) => entries,
326 Err(e) => {
327 return Err(PluginLoaderError::fs(format!(
328 "Failed to read plugin directory {}: {}",
329 plugin_path.display(),
330 e
331 )));
332 }
333 };
334
335 while let Ok(Some(entry)) = entries.next_entry().await {
336 let path = entry.path();
337 if let Some(extension) = path.extension() {
338 if extension == "wasm" {
339 return Ok(Some(path));
340 }
341 }
342 }
343
344 Ok(None)
345 }
346
347 pub async fn reload_all_plugins(&self) -> LoaderResult<PluginLoadStats> {
349 let loaded_plugins = self.list_plugins().await;
351
352 for plugin_id in &loaded_plugins {
354 if let Err(e) = self.unload_plugin(plugin_id).await {
355 tracing::warn!("Failed to unload plugin {} during reload: {}", plugin_id, e);
356 }
357 }
358
359 self.load_all_plugins().await
361 }
362
363 pub async fn reload_plugin(&self, plugin_id: &PluginId) -> LoaderResult<()> {
365 self.unload_plugin(plugin_id).await?;
367
368 self.load_plugin(plugin_id).await
370 }
371
372 pub fn registry(&self) -> Arc<RwLock<PluginRegistry>> {
374 Arc::clone(&self.registry)
375 }
376
377 pub fn validator(&self) -> &PluginValidator {
379 &self.validator
380 }
381
382 pub fn sandbox(&self) -> &PluginSandbox {
384 &self.sandbox
385 }
386}
387
388impl Default for PluginLoader {
389 fn default() -> Self {
390 Self::new(PluginLoaderConfig::default())
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_plugin_loader_creation() {
        let loader = PluginLoader::new(PluginLoaderConfig::default());

        let stats = loader.get_load_stats().await;
        assert_eq!(stats.discovered, 0);
        assert_eq!(stats.loaded, 0);
    }

    #[tokio::test]
    async fn test_load_all_plugins_with_missing_directory() {
        let mut config = PluginLoaderConfig::default();
        config.plugin_dirs = vec!["/nonexistent/mockforge-plugin-loader-test".to_string()];
        let loader = PluginLoader::new(config);

        // A missing directory is treated as empty, and repeated runs must keep succeeding.
        for _ in 0..2 {
            let stats = loader
                .load_all_plugins()
                .await
                .expect("a missing plugin directory is not an error");
            assert_eq!(stats.discovered, 0);
            assert_eq!(stats.loaded, 0);
            assert!(stats.end_time.is_some());
        }
        assert!(loader.list_plugins().await.is_empty());
    }

    #[test]
    fn test_load_stats() {
        let mut stats = PluginLoadStats::default();

        stats.start_loading();
        assert!(stats.start_time.is_some());

        stats.record_success();
        stats.record_failure();
        stats.record_skipped();

        stats.finish_loading();
        assert!(stats.end_time.is_some());

        assert_eq!(stats.loaded, 1);
        assert_eq!(stats.failed, 1);
        assert_eq!(stats.skipped, 1);
        assert_eq!(stats.discovered, 3);
        // The rate is a computed float; compare with a tolerance rather than an exact literal.
        assert!((stats.success_rate() - 100.0 / 3.0).abs() < 1e-4);
    }

    #[test]
    fn test_plugin_discovery_success() {
        let plugin_id = PluginId::new("test-plugin");
        let manifest = PluginManifest::new(PluginInfo::new(
            plugin_id.clone(),
            PluginVersion::new(1, 0, 0),
            "Test Plugin",
            "A test plugin",
            PluginAuthor::new("Test Author"),
        ));

        let discovery = PluginDiscovery::success(plugin_id, manifest, "/tmp/test".to_string());
        assert!(discovery.is_success());
        assert!(discovery.errors.is_empty());
    }

    #[test]
    fn test_plugin_discovery_failure() {
        let plugin_id = PluginId::new("test-plugin");
        let errors = vec!["Validation failed".to_string()];

        let discovery =
            PluginDiscovery::failure(plugin_id, "/tmp/test".to_string(), errors.clone());
        assert!(!discovery.is_success());
        assert_eq!(discovery.errors, errors);
        assert_eq!(discovery.first_error(), Some("Validation failed"));
    }
}