Skip to main content

room_daemon/plugin/
loader.rs

1//! Dynamic plugin loader using libloading.
2//!
3//! Scans `~/.room/plugins/` for shared libraries (`.so` on Linux, `.dylib` on
4//! macOS) and loads them via the C ABI entry points defined in
5//! [`room_protocol::plugin::abi`].
6//!
7//! Each loaded plugin goes through three stages:
8//! 1. Load the shared library
9//! 2. Read the `ROOM_PLUGIN_DECLARATION` static to verify API/protocol compat
10//! 3. Call `room_plugin_create` to obtain a `Box<dyn Plugin>`
11//!
12//! On drop, the loader calls `room_plugin_destroy` before unloading the library.
13
14use std::path::{Path, PathBuf};
15
16use room_protocol::plugin::abi::{
17    CreateFn, DestroyFn, PluginDeclaration, CREATE_SYMBOL, DECLARATION_SYMBOL, DESTROY_SYMBOL,
18};
19use room_protocol::plugin::{Plugin, PLUGIN_API_VERSION, PROTOCOL_VERSION};
20
/// A dynamically loaded plugin and its backing library handle.
///
/// The library must outlive the plugin — `Drop` calls the destroy function
/// before the library is unloaded.
///
/// This ordering holds because `Drop::drop` runs before any field is
/// dropped, so `_library` is still loaded while `destroy_fn` executes.
pub struct LoadedPlugin {
    /// Raw pointer handed back by the plugin's create function. `load_plugin`
    /// rejects a null result, so a constructed value never stores null here.
    plugin: *mut Box<dyn Plugin>,
    /// Destroy entry point resolved from the same library that produced
    /// `plugin`; invoked once from `Drop`.
    destroy_fn: DestroyFn,
    /// Keeps the shared library loaded for as long as this value lives.
    /// Never read directly (hence the leading underscore).
    _library: libloading::Library,
    /// Path to the shared library (for diagnostics).
    pub path: PathBuf,
}
32
// SAFETY: The plugin trait object is Send + Sync (required by the Plugin trait),
// and we only call destroy_fn once in Drop.
//
// These impls have to be written by hand because the raw `*mut` pointer field
// opts the struct out of the automatic `Send`/`Sync` implementations.
//
// NOTE(review): the `Plugin: Send + Sync` bound is declared in
// `room_protocol` and is not visible from this file — confirm it still holds
// whenever the trait definition changes.
unsafe impl Send for LoadedPlugin {}
unsafe impl Sync for LoadedPlugin {}
37
38impl std::fmt::Debug for LoadedPlugin {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        f.debug_struct("LoadedPlugin")
41            .field("path", &self.path)
42            .finish_non_exhaustive()
43    }
44}
45
impl LoadedPlugin {
    /// Get a reference to the loaded plugin.
    ///
    /// The borrow is tied to `&self`, so it cannot outlive this wrapper and
    /// therefore cannot outlive the library handle stored inside it.
    pub fn plugin(&self) -> &dyn Plugin {
        // SAFETY: plugin was returned by CreateFn and has not been destroyed.
        // `load_plugin` rejects a null pointer before constructing `Self`.
        unsafe { &**self.plugin }
    }

    /// Consume this wrapper and return the boxed plugin trait object.
    ///
    /// The caller takes ownership of the plugin. The destroy function will
    /// NOT be called — the returned box is dropped by the host like any
    /// other `Box<dyn Plugin>`.
    ///
    /// Because `self` is passed to `std::mem::forget`, none of its fields are
    /// dropped: the `_library` handle is leaked, so the shared library is
    /// never unloaded through this handle, and the `path` buffer is leaked
    /// along with it.
    ///
    /// # Safety
    ///
    /// The returned `Box<dyn Plugin>` references vtable entries inside the
    /// shared library, so it is only valid while that library stays mapped.
    /// This function keeps it mapped by leaking the handle as described
    /// above; the caller must not unload the library by any other means
    /// while the plugin is still alive.
    ///
    /// NOTE(review): the outer box is freed here, and the inner box is later
    /// freed by the caller, both through the host's allocator, while the
    /// allocations were made inside the plugin's create function. This
    /// presumably relies on host and plugin sharing one global allocator —
    /// confirm against the ABI contract in `room_protocol`.
    pub unsafe fn into_boxed_plugin(self) -> Box<dyn Plugin> {
        // Reclaim the outer heap box and move the inner `Box<dyn Plugin>`
        // out of it. This must happen before `self` is forgotten.
        let plugin = *Box::from_raw(self.plugin);
        // Prevent Drop from calling destroy_fn — we transferred ownership.
        std::mem::forget(self);
        plugin
    }
}
74
impl Drop for LoadedPlugin {
    /// Hand the plugin pointer back to the library's destroy function.
    ///
    /// This runs before the struct's fields are dropped, so `_library` is
    /// still loaded while `destroy_fn` executes.
    fn drop(&mut self) {
        // SAFETY: plugin was returned by CreateFn from the same library,
        // and we only call destroy once (Drop runs exactly once).
        // `into_boxed_plugin` skips this path entirely via `mem::forget`.
        unsafe {
            (self.destroy_fn)(self.plugin);
        }
    }
}
84
/// Reasons a plugin shared library can be rejected while loading.
#[derive(Debug)]
pub enum LoadError {
    /// The shared library could not be opened; carries the loader's message.
    LibraryOpen(String),
    /// A required exported symbol is missing; carries symbol name and cause.
    SymbolNotFound(String),
    /// The plugin was built against a different plugin API version.
    ApiVersionMismatch { expected: u32, found: u32 },
    /// The plugin needs a newer protocol version than the broker is running.
    ProtocolMismatch { required: String, running: String },
    /// A string field of the declaration was not valid UTF-8; carries the
    /// field name.
    InvalidUtf8(String),
    /// The plugin's create function handed back a null pointer.
    CreateReturnedNull,
}

impl std::fmt::Display for LoadError {
    /// Write the human-readable, single-line description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Fixed message: no formatting arguments needed.
            LoadError::CreateReturnedNull => f.write_str("plugin create function returned null"),
            LoadError::LibraryOpen(reason) => write!(f, "failed to open library: {reason}"),
            LoadError::SymbolNotFound(symbol) => write!(f, "symbol not found: {symbol}"),
            LoadError::InvalidUtf8(field_name) => {
                write!(f, "invalid UTF-8 in declaration field: {field_name}")
            }
            LoadError::ApiVersionMismatch { expected, found } => write!(
                f,
                "API version mismatch: expected {expected}, found {found}"
            ),
            LoadError::ProtocolMismatch { required, running } => write!(
                f,
                "protocol mismatch: plugin requires {required}, running {running}"
            ),
        }
    }
}

// No underlying source error is exposed: causes are already flattened into
// the message strings stored in the variants.
impl std::error::Error for LoadError {}
126
/// Load a plugin from a shared library at the given path.
///
/// Validates the `PluginDeclaration` against the running broker's API and
/// protocol versions before calling the create function.
///
/// `config_json`, when present, is handed to the plugin's create function as
/// a pointer/length pair (no NUL terminator is added); `None` is passed as a
/// null pointer with length 0.
///
/// # Errors
///
/// - [`LoadError::LibraryOpen`] if the library cannot be opened.
/// - [`LoadError::SymbolNotFound`] if the declaration, create or destroy
///   symbol is missing.
/// - [`LoadError::ApiVersionMismatch`] if the declared API version differs
///   from `PLUGIN_API_VERSION`.
/// - [`LoadError::InvalidUtf8`] if the declared minimum protocol string is
///   not valid UTF-8.
/// - [`LoadError::ProtocolMismatch`] if `protocol_satisfies` rejects the
///   declared minimum protocol version.
/// - [`LoadError::CreateReturnedNull`] if the create function returns null.
///
/// On every error path the local `library` handle is dropped before
/// returning.
///
/// # Safety
///
/// Although this function is not marked `unsafe`, loading a shared library
/// executes its initialization routines, which may have arbitrary side
/// effects. Only load trusted libraries.
pub fn load_plugin(path: &Path, config_json: Option<&str>) -> Result<LoadedPlugin, LoadError> {
    // SAFETY: loading a shared library is inherently unsafe — it runs init
    // routines. We trust the caller to only load vetted plugin libraries.
    let library = unsafe { libloading::Library::new(path) }
        .map_err(|e| LoadError::LibraryOpen(format!("{}: {e}", path.display())))?;

    // Read the declaration static.
    // The symbol is looked up as a pointer to the static; the double deref
    // goes through the symbol guard and then through that pointer.
    // NOTE(review): nothing here can verify that the exported symbol really
    // is a `PluginDeclaration` with the layout this build expects — that is
    // part of the trust placed in the library.
    let declaration: &PluginDeclaration = unsafe {
        let sym = library
            .get::<*const PluginDeclaration>(DECLARATION_SYMBOL)
            .map_err(|e| LoadError::SymbolNotFound(format!("ROOM_PLUGIN_DECLARATION: {e}")))?;
        &**sym
    };

    // Validate API version.
    // This is an exact-match check, unlike the protocol check below.
    if declaration.api_version != PLUGIN_API_VERSION {
        return Err(LoadError::ApiVersionMismatch {
            expected: PLUGIN_API_VERSION,
            found: declaration.api_version,
        });
    }

    // Validate protocol version (plugin's minimum must not exceed ours).
    // NOTE(review): `min_protocol` is called inside `unsafe`, so presumably it
    // reads string data through raw pointers stored in the declaration —
    // confirm its contract in `room_protocol`.
    let min_protocol = unsafe {
        declaration
            .min_protocol()
            .map_err(|_| LoadError::InvalidUtf8("min_protocol".to_owned()))?
    };
    if !protocol_satisfies(min_protocol, PROTOCOL_VERSION) {
        return Err(LoadError::ProtocolMismatch {
            required: min_protocol.to_owned(),
            running: PROTOCOL_VERSION.to_owned(),
        });
    }

    // Look up create and destroy functions.
    // The function pointers are copied out of their symbol guards; they stay
    // usable only while the library remains loaded, which the returned
    // `LoadedPlugin` ensures by owning `library`.
    let create_fn: CreateFn = unsafe {
        *library
            .get::<CreateFn>(CREATE_SYMBOL)
            .map_err(|e| LoadError::SymbolNotFound(format!("room_plugin_create: {e}")))?
    };
    let destroy_fn: DestroyFn = unsafe {
        *library
            .get::<DestroyFn>(DESTROY_SYMBOL)
            .map_err(|e| LoadError::SymbolNotFound(format!("room_plugin_destroy: {e}")))?
    };

    // Call the create function.
    // The configuration crosses the ABI as (pointer, byte length); absence is
    // signalled by a null pointer and zero length.
    let (config_ptr, config_len) = match config_json {
        Some(s) => (s.as_ptr(), s.len()),
        None => (std::ptr::null(), 0),
    };
    // SAFETY: `create_fn` was resolved from `library`, which is still loaded,
    // and `config_ptr`/`config_len` describe `config_json`'s bytes (or null/0).
    let plugin = unsafe { create_fn(config_ptr, config_len) };
    if plugin.is_null() {
        return Err(LoadError::CreateReturnedNull);
    }

    // Ownership of the library moves into the wrapper so it outlives the
    // plugin pointer and the destroy function pointer.
    Ok(LoadedPlugin {
        plugin,
        destroy_fn,
        _library: library,
        path: path.to_owned(),
    })
}
200
201/// Scan a directory for plugin shared libraries and load each one.
202///
203/// Returns successfully loaded plugins and logs warnings for any that fail.
204/// An empty or nonexistent directory returns an empty vec (not an error).
205pub fn scan_plugin_dir(dir: &Path) -> Vec<LoadedPlugin> {
206    let entries = match std::fs::read_dir(dir) {
207        Ok(e) => e,
208        Err(_) => return Vec::new(),
209    };
210
211    let mut plugins = Vec::new();
212    for entry in entries.flatten() {
213        let path = entry.path();
214        if !is_shared_lib(&path) {
215            continue;
216        }
217        match load_plugin(&path, None) {
218            Ok(loaded) => {
219                let name = loaded.plugin().name().to_owned();
220                eprintln!(
221                    "[plugin] loaded external plugin '{}' from {}",
222                    name,
223                    path.display()
224                );
225                plugins.push(loaded);
226            }
227            Err(e) => {
228                eprintln!("[plugin] failed to load plugin {}: {e}", path.display());
229            }
230        }
231    }
232    plugins
233}
234
/// Report whether `path` ends in a shared-library extension.
///
/// Only the exact, case-sensitive extensions `so` and `dylib` qualify. Paths
/// without an extension, or whose extension is not valid UTF-8, do not.
fn is_shared_lib(path: &Path) -> bool {
    let extension = path.extension().and_then(std::ffi::OsStr::to_str);
    matches!(extension, Some("so" | "dylib"))
}
241
/// Check if `running` satisfies `required` (i.e. running >= required).
///
/// Versions are compared numerically as `(major, minor, patch)` tuples.
/// Parsing is tolerant of the common spellings of a version:
///
/// - missing `minor` / `patch` components count as `0` (`"4"` and `"4.0"`
///   are read as `4.0.0`);
/// - a pre-release or build suffix (anything from the first `-` or `+`) is
///   ignored, so `"4.0.0-rc1"` is compared as `4.0.0`;
/// - components after `patch` are ignored;
/// - surrounding whitespace is ignored.
///
/// If either string still cannot be parsed (for example `"bad"` or an empty
/// string), the check is deliberately permissive and returns `true`, so an
/// oddly formatted version never blocks a plugin on its own.
fn protocol_satisfies(required: &str, running: &str) -> bool {
    /// Extract the numeric `major[.minor[.patch]]` core of a version string.
    fn parse(version: &str) -> Option<(u64, u64, u64)> {
        // Keep only the part before any pre-release / build metadata.
        let core = version.trim().split(|c| c == '-' || c == '+').next()?;
        let mut components = core.split('.');

        // `major` is mandatory; a non-numeric first component is a failure.
        let major = components.next()?.parse::<u64>().ok()?;
        // An absent component defaults to 0, but a present-yet-malformed one
        // still fails the parse.
        let mut next_or_zero = || match components.next() {
            Some(text) => text.parse::<u64>().ok(),
            None => Some(0),
        };
        let minor = next_or_zero()?;
        let patch = next_or_zero()?;
        Some((major, minor, patch))
    }

    match (parse(required), parse(running)) {
        (Some(req), Some(run)) => run >= req,
        // If either fails to parse, be permissive — let the plugin load.
        _ => true,
    }
}
265
// Unit tests for the loader. None of them needs a real plugin binary: they
// cover the pure helpers, the error paths reachable without a library, and
// the `Display` text of `LoadError`.
#[cfg(test)]
mod tests {
    use super::*;

    // --- is_shared_lib: extension filtering ---

    #[test]
    fn is_shared_lib_recognizes_so() {
        assert!(is_shared_lib(Path::new("/tmp/plugins/myplugin.so")));
    }

    #[test]
    fn is_shared_lib_recognizes_dylib() {
        assert!(is_shared_lib(Path::new("/tmp/plugins/myplugin.dylib")));
    }

    #[test]
    fn is_shared_lib_rejects_other_extensions() {
        assert!(!is_shared_lib(Path::new("/tmp/plugins/myplugin.toml")));
        assert!(!is_shared_lib(Path::new("/tmp/plugins/myplugin.json")));
        assert!(!is_shared_lib(Path::new("/tmp/plugins/myplugin.rs")));
        assert!(!is_shared_lib(Path::new("/tmp/plugins/README")));
    }

    #[test]
    fn is_shared_lib_rejects_no_extension() {
        assert!(!is_shared_lib(Path::new("/tmp/plugins/myplugin")));
    }

    // --- protocol_satisfies: arguments are (required, running) ---

    #[test]
    fn protocol_satisfies_exact_match() {
        assert!(protocol_satisfies("3.4.0", "3.4.0"));
    }

    #[test]
    fn protocol_satisfies_running_newer() {
        assert!(protocol_satisfies("3.0.0", "3.4.0"));
        assert!(protocol_satisfies("2.0.0", "3.4.0"));
    }

    #[test]
    fn protocol_satisfies_running_older_fails() {
        assert!(!protocol_satisfies("4.0.0", "3.4.0"));
        assert!(!protocol_satisfies("3.5.0", "3.4.0"));
    }

    #[test]
    fn protocol_satisfies_zero_always_passes() {
        assert!(protocol_satisfies("0.0.0", "3.4.0"));
    }

    // A version that cannot be parsed on either side must not block loading.
    #[test]
    fn protocol_satisfies_unparseable_is_permissive() {
        assert!(protocol_satisfies("bad", "3.4.0"));
        assert!(protocol_satisfies("3.4.0", "bad"));
    }

    // --- load_plugin / scan_plugin_dir: paths that need no real library ---

    #[test]
    fn load_plugin_nonexistent_path_returns_error() {
        let result = load_plugin(Path::new("/nonexistent/plugin.so"), None);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, LoadError::LibraryOpen(_)));
    }

    #[test]
    fn scan_plugin_dir_empty_dir_returns_empty() {
        let dir = tempfile::TempDir::new().unwrap();
        let plugins = scan_plugin_dir(dir.path());
        assert!(plugins.is_empty());
    }

    #[test]
    fn scan_plugin_dir_nonexistent_returns_empty() {
        let plugins = scan_plugin_dir(Path::new("/nonexistent/plugins"));
        assert!(plugins.is_empty());
    }

    // Files without a shared-library extension are never handed to the loader.
    #[test]
    fn scan_plugin_dir_skips_non_library_files() {
        let dir = tempfile::TempDir::new().unwrap();
        std::fs::write(dir.path().join("readme.txt"), "not a plugin").unwrap();
        std::fs::write(dir.path().join("config.toml"), "[plugin]").unwrap();
        let plugins = scan_plugin_dir(dir.path());
        assert!(plugins.is_empty());
    }

    // --- LoadError: Display output contains the relevant details ---

    #[test]
    fn load_error_display_messages() {
        let e = LoadError::LibraryOpen("no such file".into());
        assert!(e.to_string().contains("no such file"));

        let e = LoadError::ApiVersionMismatch {
            expected: 1,
            found: 2,
        };
        assert!(e.to_string().contains("expected 1"));
        assert!(e.to_string().contains("found 2"));

        let e = LoadError::ProtocolMismatch {
            required: "4.0.0".into(),
            running: "3.4.0".into(),
        };
        assert!(e.to_string().contains("4.0.0"));

        let e = LoadError::CreateReturnedNull;
        assert!(e.to_string().contains("null"));
    }
}