spec_ai_plugin/
loader.rs

1//! Plugin discovery and loading
2
3use crate::abi::{PluginModuleRef, PluginToolRef, PLUGIN_API_VERSION};
4use crate::error::PluginError;
5use abi_stable::library::RootModule;
6use anyhow::Result;
7use std::path::{Path, PathBuf};
8use tracing::{debug, error, info, warn};
9
/// Summary counters produced by [`PluginLoader::load_directory`].
///
/// `total` counts candidate library files; each candidate ends up in exactly one of
/// `loaded` or `failed`.
#[derive(Clone, Debug, Default)]
pub struct LoadStats {
    /// Number of candidate plugin library files that were found
    pub total: usize,
    /// Number of candidates that loaded successfully
    pub loaded: usize,
    /// Number of candidates that could not be loaded
    pub failed: usize,
    /// Sum of the tool counts reported by every successfully loaded plugin
    pub tools_loaded: usize,
}
22
23/// A loaded plugin with its metadata
24pub struct LoadedPlugin {
25    /// Path to the plugin library
26    pub path: PathBuf,
27    /// Plugin name
28    pub name: String,
29    /// Tools provided by this plugin
30    pub tools: Vec<PluginToolRef>,
31}
32
33impl std::fmt::Debug for LoadedPlugin {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("LoadedPlugin")
36            .field("path", &self.path)
37            .field("name", &self.name)
38            .field("tools_count", &self.tools.len())
39            .finish()
40    }
41}
42
43/// Plugin loader that discovers and loads plugin libraries
44pub struct PluginLoader {
45    plugins: Vec<LoadedPlugin>,
46}
47
48impl PluginLoader {
49    /// Create a new empty plugin loader
50    pub fn new() -> Self {
51        Self {
52            plugins: Vec::new(),
53        }
54    }
55
56    /// Load all plugins from a directory
57    ///
58    /// Scans the directory for dynamic library files (.dylib on macOS, .so on Linux,
59    /// .dll on Windows) and attempts to load each one as a plugin.
60    ///
61    /// # Arguments
62    /// * `dir` - Directory to scan for plugins
63    ///
64    /// # Returns
65    /// Statistics about the loading process
66    pub fn load_directory(&mut self, dir: &Path) -> Result<LoadStats> {
67        let mut stats = LoadStats::default();
68
69        if !dir.exists() {
70            info!("Plugin directory does not exist: {}", dir.display());
71            return Ok(stats);
72        }
73
74        if !dir.is_dir() {
75            return Err(PluginError::NotADirectory(dir.to_path_buf()).into());
76        }
77
78        info!("Scanning plugin directory: {}", dir.display());
79
80        for entry in walkdir::WalkDir::new(dir)
81            .max_depth(1)
82            .into_iter()
83            .filter_map(|e| e.ok())
84        {
85            let path = entry.path();
86
87            if !Self::is_plugin_library(path) {
88                continue;
89            }
90
91            stats.total += 1;
92
93            match self.load_plugin(path) {
94                Ok(tool_count) => {
95                    stats.loaded += 1;
96                    stats.tools_loaded += tool_count;
97                    info!("Loaded plugin: {} ({} tools)", path.display(), tool_count);
98                }
99                Err(e) => {
100                    stats.failed += 1;
101                    error!("Failed to load plugin {}: {}", path.display(), e);
102                }
103            }
104        }
105
106        Ok(stats)
107    }
108
109    /// Load a single plugin from a file
110    fn load_plugin(&mut self, path: &Path) -> Result<usize> {
111        debug!("Loading plugin from: {}", path.display());
112
113        // Load the root module using abi_stable
114        let module =
115            PluginModuleRef::load_from_file(path).map_err(|e| PluginError::LoadFailed {
116                path: path.to_path_buf(),
117                message: e.to_string(),
118            })?;
119
120        // Check API version compatibility
121        let plugin_version = (module.api_version())();
122        if plugin_version != PLUGIN_API_VERSION {
123            return Err(PluginError::VersionMismatch {
124                expected: PLUGIN_API_VERSION,
125                found: plugin_version,
126                path: path.to_path_buf(),
127            }
128            .into());
129        }
130
131        let plugin_name = (module.plugin_name())().to_string();
132        debug!("Plugin '{}' passed version check", plugin_name);
133
134        // Check for duplicate plugin names
135        if self.plugins.iter().any(|p| p.name == plugin_name) {
136            return Err(PluginError::DuplicatePlugin(plugin_name).into());
137        }
138
139        // Get tools from the plugin
140        let tool_refs = (module.get_tools())();
141        let tool_count = tool_refs.len();
142
143        // Collect tool refs into a Vec
144        let tools: Vec<PluginToolRef> = tool_refs.into_iter().collect();
145
146        // Call initialize on each tool if it has one
147        for tool in &tools {
148            if let Some(init) = tool.initialize {
149                let context = "{}"; // Empty context for now
150                if !init(context.into()) {
151                    warn!(
152                        "Tool '{}' initialization failed",
153                        (tool.info)().name.as_str()
154                    );
155                }
156            }
157        }
158
159        self.plugins.push(LoadedPlugin {
160            path: path.to_path_buf(),
161            name: plugin_name,
162            tools,
163        });
164
165        Ok(tool_count)
166    }
167
168    /// Check if a path is a plugin library based on extension
169    fn is_plugin_library(path: &Path) -> bool {
170        if !path.is_file() {
171            return false;
172        }
173
174        let Some(ext) = path.extension() else {
175            return false;
176        };
177
178        #[cfg(target_os = "macos")]
179        let expected = "dylib";
180
181        #[cfg(target_os = "linux")]
182        let expected = "so";
183
184        #[cfg(target_os = "windows")]
185        let expected = "dll";
186
187        #[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))]
188        let expected = "so"; // Default to .so for unknown platforms
189
190        ext == expected
191    }
192
193    /// Get all loaded plugins
194    pub fn plugins(&self) -> &[LoadedPlugin] {
195        &self.plugins
196    }
197
198    /// Get all tools from all loaded plugins as an iterator
199    pub fn all_tools(&self) -> impl Iterator<Item = (PluginToolRef, &str)> {
200        self.plugins
201            .iter()
202            .flat_map(|p| p.tools.iter().map(move |t| (*t, p.name.as_str())))
203    }
204
205    /// Get the number of loaded plugins
206    pub fn plugin_count(&self) -> usize {
207        self.plugins.len()
208    }
209
210    /// Get the total number of tools across all plugins
211    pub fn tool_count(&self) -> usize {
212        self.plugins.iter().map(|p| p.tools.len()).sum()
213    }
214}
215
216impl Default for PluginLoader {
217    fn default() -> Self {
218        Self::new()
219    }
220}
221
/// Expand a leading tilde (`~`) in `path` to the current user's home directory.
///
/// A bare `~` becomes the home directory itself, and `~/rest` becomes `<home>/rest`
/// (on Windows, `~\rest` as well). Paths whose first component is not exactly `~`,
/// including `~user/...` forms, are returned unchanged. If the home directory cannot
/// be determined, every path is returned unchanged.
pub fn expand_tilde(path: &Path) -> PathBuf {
    // `Path::strip_prefix` matches whole components, so `~user/x` is not mistaken for
    // `~/x`, repeated separators (`~//x`) cannot turn the remainder into an absolute
    // path, and non-UTF-8 paths are handled too.
    let Ok(rest) = path.strip_prefix("~") else {
        return path.to_path_buf();
    };

    match dirs_home() {
        // Joining an empty remainder would append a trailing separator.
        Some(home) if rest.as_os_str().is_empty() => home,
        Some(home) => home.join(rest),
        None => path.to_path_buf(),
    }
}

/// Get the user's home directory from `USERPROFILE` (Windows) or `HOME` (elsewhere).
///
/// Returns `None` if the variable is unset, not valid Unicode, or empty.
fn dirs_home() -> Option<PathBuf> {
    #[cfg(target_os = "windows")]
    let var = "USERPROFILE";
    #[cfg(not(target_os = "windows"))]
    let var = "HOME";

    std::env::var(var)
        .ok()
        // An empty value would silently make "expanded" paths relative.
        .filter(|home| !home.is_empty())
        .map(PathBuf::from)
}
245
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::Path;

    #[test]
    fn test_is_plugin_library_rejects_missing_paths() {
        // `is_plugin_library` requires an existing file, so non-existent paths are
        // rejected whatever their extension.
        assert!(!PluginLoader::is_plugin_library(Path::new(
            "/tmp/nonexistent/libplugin.dylib"
        )));
        assert!(!PluginLoader::is_plugin_library(Path::new(
            "/tmp/test/plugin.txt"
        )));
        assert!(!PluginLoader::is_plugin_library(Path::new(
            "/tmp/test/plugin"
        )));
    }

    #[test]
    fn test_is_plugin_library_checks_extension_of_existing_files() {
        // Use a per-process directory so concurrent test runs do not collide.
        let dir = std::env::temp_dir().join(format!(
            "spec_ai_plugin_loader_test_{}",
            std::process::id()
        ));
        fs::create_dir_all(&dir).expect("temp dir should be creatable");

        let library = dir.join(format!("libplugin.{}", std::env::consts::DLL_EXTENSION));
        let text = dir.join("plugin.txt");
        fs::write(&library, b"").expect("library stub should be writable");
        fs::write(&text, b"").expect("text stub should be writable");

        // Evaluate before cleanup, assert after, so a failure does not leak files.
        let library_ok = PluginLoader::is_plugin_library(&library);
        let text_ok = PluginLoader::is_plugin_library(&text);
        let dir_ok = PluginLoader::is_plugin_library(&dir);
        let _ = fs::remove_dir_all(&dir);

        assert!(library_ok, "existing library file should be accepted");
        assert!(!text_ok, "non-library extension should be rejected");
        assert!(!dir_ok, "directories should be rejected");
    }

    #[test]
    fn test_expand_tilde() {
        match dirs_home() {
            Some(home) => assert_eq!(expand_tilde(Path::new("~/test")), home.join("test")),
            None => assert_eq!(expand_tilde(Path::new("~/test")), Path::new("~/test")),
        }

        // Non-tilde paths should be unchanged
        assert_eq!(
            expand_tilde(Path::new("/absolute/path")),
            Path::new("/absolute/path")
        );
        // `~user` forms are not expanded
        assert_eq!(expand_tilde(Path::new("~other/x")), Path::new("~other/x"));
    }

    #[test]
    fn test_load_stats_default() {
        let stats = LoadStats::default();
        assert_eq!(stats.total, 0);
        assert_eq!(stats.loaded, 0);
        assert_eq!(stats.failed, 0);
        assert_eq!(stats.tools_loaded, 0);
    }
}