Skip to main content

pro_plugin/
host.rs

1//! Plugin host for loading and executing Wasm plugins
2
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
6use extism::{Manifest, Plugin, Wasm};
7
8use crate::manifest::{PluginConfig, PluginManifest, PluginPermissions};
9use crate::{Hook, HookContext, HookResult, PluginError, PluginResult};
10
11/// Plugin host that manages Wasm plugins
12pub struct PluginHost {
13    /// Loaded plugins
14    plugins: Vec<LoadedPlugin>,
15    /// Plugin directory
16    plugin_dir: PathBuf,
17    /// Default permissions for plugins
18    default_permissions: PluginPermissions,
19}
20
21/// A loaded plugin
22pub struct LoadedPlugin {
23    /// Plugin name
24    pub name: String,
25    /// Plugin manifest
26    pub manifest: PluginManifest,
27    /// Plugin config (from pyproject.toml)
28    pub config: Option<PluginConfig>,
29    /// Extism plugin instance
30    instance: Plugin,
31    /// Whether the plugin is enabled
32    pub enabled: bool,
33}
34
35impl LoadedPlugin {
36    /// Check if this plugin implements a hook
37    pub fn has_hook(&self, hook: Hook) -> bool {
38        self.manifest.has_hook(hook)
39    }
40
41    /// Call a hook function
42    pub fn call_hook(&mut self, hook: Hook, context: &HookContext) -> PluginResult<HookResult> {
43        let func_name = hook.function_name();
44
45        // Check if plugin implements this hook
46        if !self.has_hook(hook) {
47            return Ok(HookResult::ok());
48        }
49
50        // Serialize context
51        let input = context.to_bytes();
52
53        // Call the Wasm function
54        let output = self
55            .instance
56            .call::<&[u8], Vec<u8>>(func_name, &input)
57            .map_err(|e| PluginError::ExecutionError(e.to_string()))?;
58
59        // Deserialize result
60        let result = HookResult::from_bytes(&output)?;
61        Ok(result)
62    }
63}
64
65impl PluginHost {
66    /// Create a new plugin host
67    pub fn new(plugin_dir: impl Into<PathBuf>) -> Self {
68        Self {
69            plugins: vec![],
70            plugin_dir: plugin_dir.into(),
71            default_permissions: PluginPermissions::read_only(),
72        }
73    }
74
75    /// Create with default plugin directory (~/.rx/plugins)
76    pub fn with_default_dir() -> Self {
77        let plugin_dir = dirs::home_dir()
78            .map(|h| h.join(".rx").join("plugins"))
79            .unwrap_or_else(|| PathBuf::from(".rx/plugins"));
80
81        Self::new(plugin_dir)
82    }
83
84    /// Set default permissions for plugins
85    pub fn set_default_permissions(&mut self, permissions: PluginPermissions) {
86        self.default_permissions = permissions;
87    }
88
89    /// Get the plugin directory
90    pub fn plugin_dir(&self) -> &Path {
91        &self.plugin_dir
92    }
93
94    /// Ensure plugin directory exists
95    pub fn ensure_plugin_dir(&self) -> PluginResult<()> {
96        std::fs::create_dir_all(&self.plugin_dir).map_err(|e| {
97            PluginError::LoadError(format!("Failed to create plugin directory: {}", e))
98        })
99    }
100
101    /// Load a plugin from a Wasm file
102    pub fn load(&mut self, name: &str, wasm_path: &Path) -> PluginResult<()> {
103        self.load_with_config(name, wasm_path, None)
104    }
105
106    /// Load a plugin with specific configuration
107    pub fn load_with_config(
108        &mut self,
109        name: &str,
110        wasm_path: &Path,
111        config: Option<PluginConfig>,
112    ) -> PluginResult<()> {
113        if !wasm_path.exists() {
114            return Err(PluginError::NotFound {
115                path: wasm_path.display().to_string(),
116            });
117        }
118
119        tracing::info!("Loading plugin '{}' from {:?}", name, wasm_path);
120
121        // Read the Wasm file
122        let wasm_bytes = std::fs::read(wasm_path)
123            .map_err(|e| PluginError::LoadError(format!("Failed to read Wasm file: {}", e)))?;
124
125        // Try to extract manifest from custom section or use default
126        let manifest = self.extract_or_create_manifest(name, wasm_path, &wasm_bytes)?;
127
128        // Determine permissions
129        let permissions = config
130            .as_ref()
131            .and_then(|c| c.permissions.clone())
132            .unwrap_or_else(|| manifest.permissions.clone());
133
134        // Create Extism manifest with permissions
135        let extism_manifest = self.create_extism_manifest(&wasm_bytes, &permissions)?;
136
137        // Create the plugin instance
138        let instance = Plugin::new(&extism_manifest, [], true)
139            .map_err(|e| PluginError::LoadError(format!("Failed to create plugin: {}", e)))?;
140
141        let enabled = config.as_ref().map(|c| c.enabled).unwrap_or(true);
142
143        self.plugins.push(LoadedPlugin {
144            name: name.to_string(),
145            manifest,
146            config,
147            instance,
148            enabled,
149        });
150
151        tracing::info!("Successfully loaded plugin '{}'", name);
152        Ok(())
153    }
154
155    /// Extract manifest from Wasm or create a default one
156    fn extract_or_create_manifest(
157        &self,
158        name: &str,
159        wasm_path: &Path,
160        _wasm_bytes: &[u8],
161    ) -> PluginResult<PluginManifest> {
162        // First, try to load manifest from adjacent .toml file
163        let manifest_path = wasm_path.with_extension("toml");
164        if manifest_path.exists() {
165            let content = std::fs::read_to_string(&manifest_path).map_err(|e| {
166                PluginError::InvalidManifest(format!("Failed to read manifest: {}", e))
167            })?;
168            return PluginManifest::from_toml(&content).map_err(|e| {
169                PluginError::InvalidManifest(format!("Invalid manifest TOML: {}", e))
170            });
171        }
172
173        // TODO: Extract from Wasm custom section "rx_manifest"
174        // For now, create a default manifest
175        Ok(PluginManifest {
176            name: name.to_string(),
177            version: "0.0.0".to_string(),
178            description: String::new(),
179            author: None,
180            license: None,
181            homepage: None,
182            min_rx_version: None,
183            hooks: vec![
184                "pre_resolve".to_string(),
185                "post_resolve".to_string(),
186                "pre_build".to_string(),
187                "post_build".to_string(),
188                "pre_publish".to_string(),
189            ],
190            permissions: self.default_permissions.clone(),
191        })
192    }
193
194    /// Create Extism manifest with appropriate permissions
195    fn create_extism_manifest(
196        &self,
197        wasm_bytes: &[u8],
198        permissions: &PluginPermissions,
199    ) -> PluginResult<Manifest> {
200        let wasm = Wasm::data(wasm_bytes.to_vec());
201        let mut manifest = Manifest::new([wasm]);
202
203        // Configure allowed hosts for network access
204        if permissions.network && !permissions.allowed_hosts.is_empty() {
205            manifest = manifest.with_allowed_hosts(permissions.allowed_hosts.iter().cloned());
206        }
207
208        // Note: File system access is handled by host functions, not Extism directly
209        // We'll need to implement custom host functions for file I/O
210
211        Ok(manifest)
212    }
213
214    /// Load all plugins from a directory
215    pub fn load_from_dir(&mut self, dir: &Path) -> PluginResult<usize> {
216        if !dir.exists() {
217            return Ok(0);
218        }
219
220        let mut count = 0;
221        for entry in std::fs::read_dir(dir).map_err(|e| {
222            PluginError::LoadError(format!("Failed to read plugin directory: {}", e))
223        })? {
224            let entry = entry
225                .map_err(|e| PluginError::LoadError(format!("Failed to read entry: {}", e)))?;
226            let path = entry.path();
227
228            if path.extension().is_some_and(|ext| ext == "wasm") {
229                let name = path
230                    .file_stem()
231                    .and_then(|s| s.to_str())
232                    .unwrap_or("unknown");
233
234                match self.load(name, &path) {
235                    Ok(_) => count += 1,
236                    Err(e) => {
237                        tracing::warn!("Failed to load plugin {:?}: {}", path, e);
238                    }
239                }
240            }
241        }
242
243        Ok(count)
244    }
245
246    /// Load plugins from pyproject.toml configuration
247    pub fn load_from_config(
248        &mut self,
249        configs: &HashMap<String, PluginConfig>,
250    ) -> PluginResult<usize> {
251        let mut count = 0;
252
253        for (name, config) in configs {
254            if !config.enabled {
255                tracing::debug!("Skipping disabled plugin '{}'", name);
256                continue;
257            }
258
259            let path =
260                if config.source.starts_with("http://") || config.source.starts_with("https://") {
261                    // Download from URL
262                    match self.download_plugin(name, &config.source) {
263                        Ok(p) => p,
264                        Err(e) => {
265                            tracing::warn!("Failed to download plugin '{}': {}", name, e);
266                            continue;
267                        }
268                    }
269                } else {
270                    // Local path
271                    PathBuf::from(&config.source)
272                };
273
274            match self.load_with_config(name, &path, Some(config.clone())) {
275                Ok(_) => count += 1,
276                Err(e) => {
277                    tracing::warn!("Failed to load plugin '{}': {}", name, e);
278                }
279            }
280        }
281
282        Ok(count)
283    }
284
285    /// Download a plugin from URL
286    fn download_plugin(&self, name: &str, url: &str) -> PluginResult<PathBuf> {
287        self.ensure_plugin_dir()?;
288
289        let dest_path = self.plugin_dir.join(format!("{}.wasm", name));
290
291        // Use blocking reqwest for simplicity (this should be async in production)
292        let response = reqwest::blocking::get(url)
293            .map_err(|e| PluginError::LoadError(format!("Failed to download plugin: {}", e)))?;
294
295        if !response.status().is_success() {
296            return Err(PluginError::LoadError(format!(
297                "Failed to download plugin: HTTP {}",
298                response.status()
299            )));
300        }
301
302        let bytes = response
303            .bytes()
304            .map_err(|e| PluginError::LoadError(format!("Failed to read response: {}", e)))?;
305
306        std::fs::write(&dest_path, &bytes)
307            .map_err(|e| PluginError::LoadError(format!("Failed to save plugin: {}", e)))?;
308
309        Ok(dest_path)
310    }
311
312    /// Execute a hook on all enabled plugins that implement it
313    pub fn execute_hook(&mut self, hook: Hook, context: &HookContext) -> PluginResult<HookResult> {
314        tracing::debug!("Executing hook {:?}", hook);
315
316        let mut combined_result = HookResult::ok();
317
318        for plugin in &mut self.plugins {
319            if !plugin.enabled {
320                continue;
321            }
322
323            if !plugin.has_hook(hook) {
324                continue;
325            }
326
327            tracing::trace!("Running hook {:?} on plugin '{}'", hook, plugin.name);
328
329            match plugin.call_hook(hook, context) {
330                Ok(result) => {
331                    // Print any messages
332                    for msg in &result.messages {
333                        println!("[{}] {}", plugin.name, msg);
334                    }
335
336                    combined_result.merge(result);
337
338                    // Stop if plugin requested to halt
339                    if !combined_result.continue_operation {
340                        tracing::info!("Plugin '{}' stopped operation at {:?}", plugin.name, hook);
341                        break;
342                    }
343                }
344                Err(e) => {
345                    tracing::warn!("Plugin '{}' hook {:?} failed: {}", plugin.name, hook, e);
346                    // Continue with other plugins unless it's a critical error
347                }
348            }
349        }
350
351        Ok(combined_result)
352    }
353
354    /// Get the number of loaded plugins
355    pub fn plugin_count(&self) -> usize {
356        self.plugins.len()
357    }
358
359    /// Get the number of enabled plugins
360    pub fn enabled_count(&self) -> usize {
361        self.plugins.iter().filter(|p| p.enabled).count()
362    }
363
364    /// List all loaded plugins
365    pub fn list_plugins(&self) -> Vec<&LoadedPlugin> {
366        self.plugins.iter().collect()
367    }
368
369    /// Get a plugin by name
370    pub fn get_plugin(&self, name: &str) -> Option<&LoadedPlugin> {
371        self.plugins.iter().find(|p| p.name == name)
372    }
373
374    /// Remove a plugin by name
375    pub fn remove_plugin(&mut self, name: &str) -> bool {
376        let len_before = self.plugins.len();
377        self.plugins.retain(|p| p.name != name);
378        self.plugins.len() < len_before
379    }
380
381    /// Enable a plugin
382    pub fn enable_plugin(&mut self, name: &str) -> bool {
383        if let Some(plugin) = self.plugins.iter_mut().find(|p| p.name == name) {
384            plugin.enabled = true;
385            true
386        } else {
387            false
388        }
389    }
390
391    /// Disable a plugin
392    pub fn disable_plugin(&mut self, name: &str) -> bool {
393        if let Some(plugin) = self.plugins.iter_mut().find(|p| p.name == name) {
394            plugin.enabled = false;
395            true
396        } else {
397            false
398        }
399    }
400}
401
402impl Default for PluginHost {
403    fn default() -> Self {
404        Self::with_default_dir()
405    }
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created host has no plugins loaded.
    #[test]
    fn test_plugin_host_creation() {
        let host = PluginHost::with_default_dir();
        let loaded = host.plugin_count();
        assert_eq!(loaded, 0);
    }

    /// Loading a path that does not exist reports an error.
    #[test]
    fn test_load_nonexistent_plugin() {
        let missing = Path::new("/nonexistent/plugin.wasm");
        let mut host = PluginHost::with_default_dir();
        let outcome = host.load("test", missing);
        assert!(outcome.is_err());
    }
}