1use std::path::{Path, PathBuf};
15
16use room_protocol::plugin::abi::{
17 CreateFn, DestroyFn, PluginDeclaration, CREATE_SYMBOL, DECLARATION_SYMBOL, DESTROY_SYMBOL,
18};
19use room_protocol::plugin::{Plugin, PLUGIN_API_VERSION, PROTOCOL_VERSION};
20
/// A plugin instance created from a dynamically loaded shared library.
///
/// Values are built by [`load_plugin`]. When a value is dropped, the stored
/// `destroy_fn` is invoked on `plugin` (see the `Drop` impl); the remaining
/// fields, including the library handle, are released only after that call.
pub struct LoadedPlugin {
    /// Raw pointer to the boxed plugin trait object returned by the library's
    /// create function. [`load_plugin`] rejects a null pointer before
    /// constructing this struct.
    plugin: *mut Box<dyn Plugin>,
    /// Destroy entry point resolved from the same library; called with
    /// `plugin` when this value is dropped.
    destroy_fn: DestroyFn,
    /// Handle to the shared library. It is never read, but it is held here so
    /// the library is not released while `plugin` and `destroy_fn` are in use.
    _library: libloading::Library,
    /// Filesystem path the library was loaded from.
    pub path: PathBuf,
}
32
// SAFETY: `LoadedPlugin` contains a raw pointer, so `Send` and `Sync` are not
// derived automatically and are asserted by hand here.
// NOTE(review): this is only sound if the plugin object behind the pointer may
// be moved to and shared between threads — presumably the `Plugin` trait
// requires `Send + Sync`; confirm against the trait definition.
unsafe impl Send for LoadedPlugin {}
unsafe impl Sync for LoadedPlugin {}
37
38impl std::fmt::Debug for LoadedPlugin {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 f.debug_struct("LoadedPlugin")
41 .field("path", &self.path)
42 .finish_non_exhaustive()
43 }
44}
45
46impl LoadedPlugin {
47 pub fn plugin(&self) -> &dyn Plugin {
49 unsafe { &**self.plugin }
51 }
52
53 pub unsafe fn into_boxed_plugin(self) -> Box<dyn Plugin> {
68 let plugin = *Box::from_raw(self.plugin);
69 std::mem::forget(self);
71 plugin
72 }
73}
74
75impl Drop for LoadedPlugin {
76 fn drop(&mut self) {
77 unsafe {
80 (self.destroy_fn)(self.plugin);
81 }
82 }
83}
84
/// Reasons [`load_plugin`] can fail.
///
/// `Clone`, `PartialEq` and `Eq` are derived so callers and tests can store
/// and compare errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoadError {
    /// The shared library could not be opened. Carries the path and the
    /// loader's error text.
    LibraryOpen(String),
    /// A required exported symbol could not be resolved. Carries the symbol
    /// name and the loader's error text.
    SymbolNotFound(String),
    /// The plugin declares an API version other than the one this build
    /// expects.
    ApiVersionMismatch { expected: u32, found: u32 },
    /// The plugin requires a protocol version the running build does not
    /// satisfy.
    ProtocolMismatch { required: String, running: String },
    /// A string field of the plugin declaration was not valid UTF-8. Carries
    /// the field name.
    InvalidUtf8(String),
    /// The plugin's create function returned a null pointer.
    CreateReturnedNull,
}

impl std::fmt::Display for LoadError {
    /// Writes a one-line, human-readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::LibraryOpen(e) => write!(f, "failed to open library: {e}"),
            Self::SymbolNotFound(s) => write!(f, "symbol not found: {s}"),
            Self::ApiVersionMismatch { expected, found } => write!(
                f,
                "API version mismatch: expected {expected}, found {found}"
            ),
            Self::ProtocolMismatch { required, running } => write!(
                f,
                "protocol mismatch: plugin requires {required}, running {running}"
            ),
            Self::InvalidUtf8(field) => write!(f, "invalid UTF-8 in declaration field: {field}"),
            Self::CreateReturnedNull => f.write_str("plugin create function returned null"),
        }
    }
}

impl std::error::Error for LoadError {}
126
127pub fn load_plugin(path: &Path, config_json: Option<&str>) -> Result<LoadedPlugin, LoadError> {
137 let library = unsafe { libloading::Library::new(path) }
140 .map_err(|e| LoadError::LibraryOpen(format!("{}: {e}", path.display())))?;
141
142 let declaration: &PluginDeclaration = unsafe {
144 let sym = library
145 .get::<*const PluginDeclaration>(DECLARATION_SYMBOL)
146 .map_err(|e| LoadError::SymbolNotFound(format!("ROOM_PLUGIN_DECLARATION: {e}")))?;
147 &**sym
148 };
149
150 if declaration.api_version != PLUGIN_API_VERSION {
152 return Err(LoadError::ApiVersionMismatch {
153 expected: PLUGIN_API_VERSION,
154 found: declaration.api_version,
155 });
156 }
157
158 let min_protocol = unsafe {
160 declaration
161 .min_protocol()
162 .map_err(|_| LoadError::InvalidUtf8("min_protocol".to_owned()))?
163 };
164 if !protocol_satisfies(min_protocol, PROTOCOL_VERSION) {
165 return Err(LoadError::ProtocolMismatch {
166 required: min_protocol.to_owned(),
167 running: PROTOCOL_VERSION.to_owned(),
168 });
169 }
170
171 let create_fn: CreateFn = unsafe {
173 *library
174 .get::<CreateFn>(CREATE_SYMBOL)
175 .map_err(|e| LoadError::SymbolNotFound(format!("room_plugin_create: {e}")))?
176 };
177 let destroy_fn: DestroyFn = unsafe {
178 *library
179 .get::<DestroyFn>(DESTROY_SYMBOL)
180 .map_err(|e| LoadError::SymbolNotFound(format!("room_plugin_destroy: {e}")))?
181 };
182
183 let (config_ptr, config_len) = match config_json {
185 Some(s) => (s.as_ptr(), s.len()),
186 None => (std::ptr::null(), 0),
187 };
188 let plugin = unsafe { create_fn(config_ptr, config_len) };
189 if plugin.is_null() {
190 return Err(LoadError::CreateReturnedNull);
191 }
192
193 Ok(LoadedPlugin {
194 plugin,
195 destroy_fn,
196 _library: library,
197 path: path.to_owned(),
198 })
199}
200
201pub fn scan_plugin_dir(dir: &Path) -> Vec<LoadedPlugin> {
206 let entries = match std::fs::read_dir(dir) {
207 Ok(e) => e,
208 Err(_) => return Vec::new(),
209 };
210
211 let mut plugins = Vec::new();
212 for entry in entries.flatten() {
213 let path = entry.path();
214 if !is_shared_lib(&path) {
215 continue;
216 }
217 match load_plugin(&path, None) {
218 Ok(loaded) => {
219 let name = loaded.plugin().name().to_owned();
220 eprintln!(
221 "[plugin] loaded external plugin '{}' from {}",
222 name,
223 path.display()
224 );
225 plugins.push(loaded);
226 }
227 Err(e) => {
228 eprintln!("[plugin] failed to load plugin {}: {e}", path.display());
229 }
230 }
231 }
232 plugins
233}
234
/// Returns `true` when the extension of `path` is exactly `so` or `dylib`.
///
/// The comparison is case-sensitive and looks only at the final extension, so
/// a versioned name such as `libfoo.so.1` is not accepted.
fn is_shared_lib(path: &Path) -> bool {
    let extension = path.extension().and_then(|raw| raw.to_str());
    matches!(extension, Some("so" | "dylib"))
}
241
/// Reports whether the `running` protocol version meets the `required` one.
///
/// Both strings are read as `major.minor.patch` and compared as numeric
/// triples; any fields after the third are ignored. A pre-release or build
/// suffix on the patch field (`3.4.0-rc.1`, `3.4.0+build.7`) is tolerated and
/// ignored, so `3.4.0-rc.1` is compared as `3.4.0`. Without this, a suffixed
/// version would fail to parse and skip the check entirely.
///
/// If either string still cannot be parsed the check is permissive and
/// returns `true`.
fn protocol_satisfies(required: &str, running: &str) -> bool {
    /// Parses the patch field, falling back to the digits before a `-`/`+` suffix.
    fn patch_number(field: &str) -> Option<u64> {
        field.parse().ok().or_else(|| {
            let (digits, _suffix) = field.split_once(['-', '+'])?;
            digits.parse().ok()
        })
    }

    /// Extracts the numeric `(major, minor, patch)` triple from a version string.
    fn parse_version(version: &str) -> Option<(u64, u64, u64)> {
        let mut fields = version.split('.');
        let major = fields.next()?.parse().ok()?;
        let minor = fields.next()?.parse().ok()?;
        let patch = patch_number(fields.next()?)?;
        Some((major, minor, patch))
    }

    match (parse_version(required), parse_version(running)) {
        // Tuples compare lexicographically: major, then minor, then patch.
        (Some(req), Some(run)) => run >= req,
        // Deliberately permissive when a version is not understood.
        _ => true,
    }
}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a path inside a fictional plugin directory.
    fn plugin_file(name: &str) -> PathBuf {
        Path::new("/tmp/plugins").join(name)
    }

    // ---- is_shared_lib ----

    #[test]
    fn is_shared_lib_recognizes_so() {
        let candidate = plugin_file("myplugin.so");
        assert!(is_shared_lib(&candidate));
    }

    #[test]
    fn is_shared_lib_recognizes_dylib() {
        let candidate = plugin_file("myplugin.dylib");
        assert!(is_shared_lib(&candidate));
    }

    #[test]
    fn is_shared_lib_rejects_other_extensions() {
        for name in ["myplugin.toml", "myplugin.json", "myplugin.rs", "README"] {
            assert!(!is_shared_lib(&plugin_file(name)));
        }
    }

    #[test]
    fn is_shared_lib_rejects_no_extension() {
        let candidate = plugin_file("myplugin");
        assert!(!is_shared_lib(&candidate));
    }

    // ---- protocol_satisfies ----

    #[test]
    fn protocol_satisfies_exact_match() {
        let version = "3.4.0";
        assert!(protocol_satisfies(version, version));
    }

    #[test]
    fn protocol_satisfies_running_newer() {
        for required in ["3.0.0", "2.0.0"] {
            assert!(protocol_satisfies(required, "3.4.0"));
        }
    }

    #[test]
    fn protocol_satisfies_running_older_fails() {
        for required in ["4.0.0", "3.5.0"] {
            assert!(!protocol_satisfies(required, "3.4.0"));
        }
    }

    #[test]
    fn protocol_satisfies_zero_always_passes() {
        assert!(protocol_satisfies("0.0.0", "3.4.0"));
    }

    #[test]
    fn protocol_satisfies_unparseable_is_permissive() {
        // A version that cannot be parsed on either side must not block loading.
        for (required, running) in [("bad", "3.4.0"), ("3.4.0", "bad")] {
            assert!(protocol_satisfies(required, running));
        }
    }

    // ---- load_plugin / scan_plugin_dir ----

    #[test]
    fn load_plugin_nonexistent_path_returns_error() {
        let missing = Path::new("/nonexistent/plugin.so");
        let outcome = load_plugin(missing, None);
        assert!(outcome.is_err());
        assert!(matches!(outcome.unwrap_err(), LoadError::LibraryOpen(_)));
    }

    #[test]
    fn scan_plugin_dir_empty_dir_returns_empty() {
        let scratch = tempfile::TempDir::new().unwrap();
        assert!(scan_plugin_dir(scratch.path()).is_empty());
    }

    #[test]
    fn scan_plugin_dir_nonexistent_returns_empty() {
        let missing = Path::new("/nonexistent/plugins");
        assert!(scan_plugin_dir(missing).is_empty());
    }

    #[test]
    fn scan_plugin_dir_skips_non_library_files() {
        let scratch = tempfile::TempDir::new().unwrap();
        for (file_name, contents) in [("readme.txt", "not a plugin"), ("config.toml", "[plugin]")] {
            std::fs::write(scratch.path().join(file_name), contents).unwrap();
        }
        assert!(scan_plugin_dir(scratch.path()).is_empty());
    }

    // ---- LoadError ----

    #[test]
    fn load_error_display_messages() {
        let rendered = LoadError::LibraryOpen("no such file".into()).to_string();
        assert!(rendered.contains("no such file"));

        let rendered = LoadError::ApiVersionMismatch {
            expected: 1,
            found: 2,
        }
        .to_string();
        assert!(rendered.contains("expected 1"));
        assert!(rendered.contains("found 2"));

        let rendered = LoadError::ProtocolMismatch {
            required: "4.0.0".into(),
            running: "3.4.0".into(),
        }
        .to_string();
        assert!(rendered.contains("4.0.0"));

        let rendered = LoadError::CreateReturnedNull.to_string();
        assert!(rendered.contains("null"));
    }
}