Skip to main content

rustbridge_consumer/
loader.rs

1//! Plugin loader for dynamically loading rustbridge plugins.
2
3use crate::error::{ConsumerError, ConsumerResult};
4use crate::ffi_bindings::{
5    FfiPluginHandle, LogCallback, PluginCallFn, PluginCallRawFn, PluginCreateFn,
6    PluginFreeBufferFn, PluginGetRejectedCountFn, PluginGetStateFn, PluginInitFn,
7    PluginSetLogLevelFn, PluginShutdownFn, RbResponseFreeFn,
8};
9use crate::plugin::NativePlugin;
10use libloading::Library;
11use rustbridge_bundle::BundleLoader;
12use rustbridge_core::{LogLevel, PluginConfig};
13use std::ffi::c_char;
14use std::path::Path;
15use std::sync::Arc;
16use std::sync::atomic::{AtomicU64, Ordering};
17use tracing::debug;
18
/// Process-wide monotonic counter that gives every bundle extraction its own
/// directory.
///
/// Without a unique directory, concurrent loads would extract to the same path
/// and one thread could truncate the `.so` file while another has it mmap'd,
/// which raises SIGBUS.
///
/// NOTE(review): the counter starts at 0 in every process, so uniqueness is
/// only guaranteed within a single process — confirm whether concurrent
/// processes loading the same bundle need protection as well.
static EXTRACT_INSTANCE: AtomicU64 = AtomicU64::new(0);
23
/// Rust-friendly log callback type.
///
/// The callback is invoked once per log record forwarded from the plugin and
/// receives, in order: the log level, the log target, and the message text.
///
/// It is shared through an [`Arc`] and must be `Send + Sync`.
pub type LogCallbackFn = Arc<dyn Fn(LogLevel, &str, &str) + Send + Sync>;
28
// Per-thread storage for the active log callback.
//
// The FFI callback is a plain `extern "C"` function and cannot capture state,
// so it looks the Rust callback up in this slot instead. Because the slot is
// thread-local, a callback registered on one thread is not visible when the
// FFI callback runs on a different thread.
thread_local! {
    static LOG_CALLBACK: std::cell::RefCell<Option<LogCallbackFn>> = const { std::cell::RefCell::new(None) };
}
34
35/// Set the thread-local log callback.
36fn set_log_callback(callback: Option<LogCallbackFn>) {
37    LOG_CALLBACK.with(|cb| {
38        *cb.borrow_mut() = callback;
39    });
40}
41
42/// FFI-compatible log callback that forwards to the Rust callback.
43///
44/// # Safety
45/// - `target` must be a valid null-terminated C string or null
46/// - `message` must be valid for `message_len` bytes or null
47unsafe extern "C" fn ffi_log_callback(
48    level: u8,
49    target: *const c_char,
50    message: *const u8,
51    message_len: usize,
52) {
53    LOG_CALLBACK.with(|cb| {
54        if let Some(callback) = cb.borrow().as_ref() {
55            let log_level = LogLevel::from_u8(level);
56
57            // SAFETY: target is a valid null-terminated C string
58            let target_str = if target.is_null() {
59                ""
60            } else {
61                unsafe { std::ffi::CStr::from_ptr(target) }
62                    .to_str()
63                    .unwrap_or("")
64            };
65
66            // SAFETY: message is valid for message_len bytes (NOT null-terminated)
67            let message_str = if message.is_null() || message_len == 0 {
68                ""
69            } else {
70                let bytes = unsafe { std::slice::from_raw_parts(message, message_len) };
71                std::str::from_utf8(bytes).unwrap_or("")
72            };
73
74            callback(log_level, target_str, message_str);
75        }
76    });
77}
78
/// Loader for native plugins.
///
/// A stateless unit type: every loading routine is an associated function,
/// so no instance is ever constructed. Plugins can be loaded from a shared
/// library path, from a bundle file, or by library name.
pub struct NativePluginLoader;
83
84impl NativePluginLoader {
85    /// Load a plugin from a shared library path.
86    ///
87    /// Uses default configuration and no log callback.
88    ///
89    /// # Arguments
90    ///
91    /// * `path` - Path to the shared library (.so, .dylib, or .dll)
92    ///
93    /// # Example
94    ///
95    /// ```ignore
96    /// let plugin = NativePluginLoader::load("target/release/libmy_plugin.so")?;
97    /// ```
98    pub fn load<P: AsRef<Path>>(path: P) -> ConsumerResult<NativePlugin> {
99        Self::load_with_config(path, &PluginConfig::default(), None)
100    }
101
102    /// Load a plugin with custom configuration.
103    ///
104    /// # Arguments
105    ///
106    /// * `path` - Path to the shared library
107    /// * `config` - Plugin configuration
108    /// * `log_callback` - Optional callback to receive log messages
109    ///
110    /// # Example
111    ///
112    /// ```ignore
113    /// let config = PluginConfig::default();
114    /// let log_callback: LogCallbackFn = Arc::new(|level, target, msg| {
115    ///     println!("[{level}] {target}: {msg}");
116    /// });
117    ///
118    /// let plugin = NativePluginLoader::load_with_config(
119    ///     "target/release/libmy_plugin.so",
120    ///     &config,
121    ///     Some(log_callback),
122    /// )?;
123    /// ```
124    pub fn load_with_config<P: AsRef<Path>>(
125        path: P,
126        config: &PluginConfig,
127        log_callback: Option<LogCallbackFn>,
128    ) -> ConsumerResult<NativePlugin> {
129        let path = path.as_ref();
130        debug!("Loading plugin from: {}", path.display());
131
132        // Load the shared library
133        // SAFETY: We're loading a shared library which requires unsafe
134        let library = unsafe { Library::new(path) }?;
135
136        // Load required symbols
137        let plugin_create: PluginCreateFn = unsafe { *library.get(b"plugin_create\0")? };
138        let plugin_init: PluginInitFn = unsafe { *library.get(b"plugin_init\0")? };
139        let plugin_call: PluginCallFn = unsafe { *library.get(b"plugin_call\0")? };
140        let plugin_shutdown: PluginShutdownFn = unsafe { *library.get(b"plugin_shutdown\0")? };
141        let plugin_get_state: PluginGetStateFn = unsafe { *library.get(b"plugin_get_state\0")? };
142        let plugin_get_rejected_count: PluginGetRejectedCountFn =
143            unsafe { *library.get(b"plugin_get_rejected_count\0")? };
144        let plugin_set_log_level: PluginSetLogLevelFn =
145            unsafe { *library.get(b"plugin_set_log_level\0")? };
146        let plugin_free_buffer: PluginFreeBufferFn =
147            unsafe { *library.get(b"plugin_free_buffer\0")? };
148
149        // Load optional binary transport symbols
150        let plugin_call_raw: Option<PluginCallRawFn> =
151            unsafe { library.get(b"plugin_call_raw\0").ok().map(|s| *s) };
152        let rb_response_free: Option<RbResponseFreeFn> =
153            unsafe { library.get(b"rb_response_free\0").ok().map(|s| *s) };
154
155        // Set up log callback if provided
156        set_log_callback(log_callback);
157        let ffi_callback: LogCallback = Some(ffi_log_callback);
158
159        // Create the plugin instance
160        // SAFETY: plugin_create returns a valid pointer or null
161        let plugin_ptr = unsafe { plugin_create() };
162        if plugin_ptr.is_null() {
163            return Err(ConsumerError::NullHandle);
164        }
165
166        // Serialize config to JSON
167        let config_json = serde_json::to_vec(config)?;
168
169        // Initialize the plugin
170        // SAFETY: plugin_ptr is valid, config_json is valid for its length
171        let handle: FfiPluginHandle = unsafe {
172            plugin_init(
173                plugin_ptr,
174                config_json.as_ptr(),
175                config_json.len(),
176                ffi_callback,
177            )
178        };
179
180        if handle.is_null() {
181            return Err(ConsumerError::NullHandle);
182        }
183
184        debug!("Plugin initialized with handle: {:?}", handle);
185
186        // SAFETY: All pointers are valid and came from the library
187        Ok(unsafe {
188            NativePlugin::new(
189                library,
190                handle,
191                plugin_call,
192                plugin_call_raw,
193                plugin_shutdown,
194                plugin_get_state,
195                plugin_get_rejected_count,
196                plugin_set_log_level,
197                plugin_free_buffer,
198                rb_response_free,
199            )
200        })
201    }
202
203    /// Load a plugin from a bundle file.
204    ///
205    /// Extracts the library for the current platform and loads it.
206    ///
207    /// # Arguments
208    ///
209    /// * `bundle_path` - Path to the .rbp bundle file
210    ///
211    /// # Example
212    ///
213    /// ```ignore
214    /// let plugin = NativePluginLoader::load_bundle("my-plugin-1.0.0.rbp")?;
215    /// ```
216    pub fn load_bundle<P: AsRef<Path>>(bundle_path: P) -> ConsumerResult<NativePlugin> {
217        Self::load_bundle_with_config(bundle_path, &PluginConfig::default(), None)
218    }
219
220    /// Load a plugin from a bundle with custom configuration.
221    ///
222    /// # Arguments
223    ///
224    /// * `bundle_path` - Path to the .rbp bundle file
225    /// * `config` - Plugin configuration
226    /// * `log_callback` - Optional callback to receive log messages
227    ///
228    /// # Example
229    ///
230    /// ```ignore
231    /// let config = PluginConfig::default();
232    /// let plugin = NativePluginLoader::load_bundle_with_config(
233    ///     "my-plugin-1.0.0.rbp",
234    ///     &config,
235    ///     None,
236    /// )?;
237    /// ```
238    pub fn load_bundle_with_config<P: AsRef<Path>>(
239        bundle_path: P,
240        config: &PluginConfig,
241        log_callback: Option<LogCallbackFn>,
242    ) -> ConsumerResult<NativePlugin> {
243        let bundle_path = bundle_path.as_ref();
244        debug!("Loading bundle from: {}", bundle_path.display());
245
246        // Open and validate the bundle
247        let mut loader = BundleLoader::open(bundle_path)?;
248
249        // Check platform support
250        if !loader.supports_current_platform() {
251            return Err(ConsumerError::Bundle(
252                rustbridge_bundle::BundleError::UnsupportedPlatform(
253                    "Current platform not supported by bundle".to_string(),
254                ),
255            ));
256        }
257
258        // Each load gets a unique extraction directory to prevent SIGBUS from
259        // concurrent threads overwriting a file that another thread has mmap'd.
260        let instance_id = EXTRACT_INSTANCE.fetch_add(1, Ordering::Relaxed);
261        let extract_dir = bundle_path
262            .parent()
263            .unwrap_or(Path::new("."))
264            .join(".rustbridge-cache")
265            .join(loader.manifest().plugin.name.as_str())
266            .join(loader.manifest().plugin.version.as_str())
267            .join(instance_id.to_string());
268
269        // Extract the library for the current platform
270        let lib_path = loader.extract_library_for_current_platform(&extract_dir)?;
271
272        debug!("Extracted library to: {}", lib_path.display());
273
274        // Load the extracted library
275        Self::load_with_config(lib_path, config, log_callback)
276    }
277
278    /// Load a specific variant from a bundle with custom configuration.
279    ///
280    /// Unlike `load_bundle_with_config` which always extracts the default (release) variant,
281    /// this method extracts the named variant (e.g., "debug", "release").
282    ///
283    /// # Arguments
284    ///
285    /// * `bundle_path` - Path to the .rbp bundle file
286    /// * `variant` - Variant name (e.g., "release", "debug")
287    /// * `config` - Plugin configuration
288    /// * `log_callback` - Optional callback to receive log messages
289    ///
290    /// # Example
291    ///
292    /// ```ignore
293    /// let config = PluginConfig::default();
294    /// let plugin = NativePluginLoader::load_bundle_variant_with_config(
295    ///     "my-plugin-1.0.0.rbp",
296    ///     "debug",
297    ///     &config,
298    ///     None,
299    /// )?;
300    /// ```
301    pub fn load_bundle_variant_with_config<P: AsRef<Path>>(
302        bundle_path: P,
303        variant: &str,
304        config: &PluginConfig,
305        log_callback: Option<LogCallbackFn>,
306    ) -> ConsumerResult<NativePlugin> {
307        let bundle_path = bundle_path.as_ref();
308        debug!(
309            "Loading bundle variant '{}' from: {}",
310            variant,
311            bundle_path.display()
312        );
313
314        // Open and validate the bundle
315        let mut loader = BundleLoader::open(bundle_path)?;
316
317        // Check platform support
318        let platform = rustbridge_bundle::Platform::current().ok_or_else(|| {
319            ConsumerError::Bundle(rustbridge_bundle::BundleError::UnsupportedPlatform(
320                "Current platform is not supported".to_string(),
321            ))
322        })?;
323
324        if !loader.supports_current_platform() {
325            return Err(ConsumerError::Bundle(
326                rustbridge_bundle::BundleError::UnsupportedPlatform(
327                    "Current platform not supported by bundle".to_string(),
328                ),
329            ));
330        }
331
332        // Each load gets a unique extraction directory to prevent SIGBUS from
333        // concurrent threads overwriting a file that another thread has mmap'd.
334        let instance_id = EXTRACT_INSTANCE.fetch_add(1, Ordering::Relaxed);
335        let extract_dir = bundle_path
336            .parent()
337            .unwrap_or(Path::new("."))
338            .join(".rustbridge-cache")
339            .join(loader.manifest().plugin.name.as_str())
340            .join(loader.manifest().plugin.version.as_str())
341            .join(format!("{variant}-{instance_id}"));
342
343        // Extract the specified variant
344        let lib_path = loader.extract_library_variant(platform, variant, &extract_dir)?;
345
346        debug!("Extracted variant library to: {}", lib_path.display());
347
348        // Load the extracted library
349        Self::load_with_config(lib_path, config, log_callback)
350    }
351
352    /// Load a plugin from a bundle to a specific extraction directory.
353    ///
354    /// This is useful when you want to control where the library is extracted.
355    ///
356    /// # Arguments
357    ///
358    /// * `bundle_path` - Path to the .rbp bundle file
359    /// * `extract_dir` - Directory to extract the library to
360    /// * `config` - Plugin configuration
361    /// * `log_callback` - Optional callback to receive log messages
362    pub fn load_bundle_to_dir<P: AsRef<Path>, Q: AsRef<Path>>(
363        bundle_path: P,
364        extract_dir: Q,
365        config: &PluginConfig,
366        log_callback: Option<LogCallbackFn>,
367    ) -> ConsumerResult<NativePlugin> {
368        Self::load_bundle_verified(
369            bundle_path,
370            Some(extract_dir),
371            config,
372            log_callback,
373            false,
374            None,
375        )
376    }
377
378    /// Load a plugin from a bundle with signature verification.
379    ///
380    /// # Arguments
381    ///
382    /// * `bundle_path` - Path to the .rbp bundle file
383    /// * `config` - Plugin configuration
384    /// * `log_callback` - Optional callback to receive log messages
385    /// * `verify_signatures` - Whether to verify minisign signatures
386    /// * `public_key_override` - Optional public key to use instead of manifest's key
387    ///
388    /// # Example
389    ///
390    /// ```ignore
391    /// // Load with signature verification (recommended for production)
392    /// let plugin = NativePluginLoader::load_bundle_with_verification(
393    ///     "my-plugin-1.0.0.rbp",
394    ///     &PluginConfig::default(),
395    ///     None,
396    ///     true,  // verify signatures
397    ///     None,  // use manifest's public key
398    /// )?;
399    /// ```
400    pub fn load_bundle_with_verification<P: AsRef<Path>>(
401        bundle_path: P,
402        config: &PluginConfig,
403        log_callback: Option<LogCallbackFn>,
404        verify_signatures: bool,
405        public_key_override: Option<&str>,
406    ) -> ConsumerResult<NativePlugin> {
407        Self::load_bundle_verified(
408            bundle_path,
409            None::<&Path>,
410            config,
411            log_callback,
412            verify_signatures,
413            public_key_override,
414        )
415    }
416
417    /// Internal method to load a bundle with all options.
418    fn load_bundle_verified<P: AsRef<Path>, Q: AsRef<Path>>(
419        bundle_path: P,
420        extract_dir: Option<Q>,
421        config: &PluginConfig,
422        log_callback: Option<LogCallbackFn>,
423        verify_signatures: bool,
424        public_key_override: Option<&str>,
425    ) -> ConsumerResult<NativePlugin> {
426        let bundle_path = bundle_path.as_ref();
427        debug!("Loading bundle from: {}", bundle_path.display());
428
429        // Open and validate the bundle
430        let mut loader = BundleLoader::open(bundle_path)?;
431
432        // Check platform support
433        let platform = rustbridge_bundle::Platform::current().ok_or_else(|| {
434            ConsumerError::Bundle(rustbridge_bundle::BundleError::UnsupportedPlatform(
435                "Current platform is not supported".to_string(),
436            ))
437        })?;
438
439        if !loader.supports_current_platform() {
440            return Err(ConsumerError::Bundle(
441                rustbridge_bundle::BundleError::UnsupportedPlatform(
442                    "Current platform not supported by bundle".to_string(),
443                ),
444            ));
445        }
446
447        // Determine extraction directory
448        // When no explicit dir is provided, each load gets a unique directory to
449        // prevent SIGBUS from concurrent threads overwriting a mmap'd file.
450        let extract_dir_path: std::path::PathBuf = match extract_dir {
451            Some(dir) => dir.as_ref().to_path_buf(),
452            None => {
453                let instance_id = EXTRACT_INSTANCE.fetch_add(1, Ordering::Relaxed);
454                bundle_path
455                    .parent()
456                    .unwrap_or(Path::new("."))
457                    .join(".rustbridge-cache")
458                    .join(loader.manifest().plugin.name.as_str())
459                    .join(loader.manifest().plugin.version.as_str())
460                    .join(instance_id.to_string())
461            }
462        };
463
464        // Extract with or without verification
465        let lib_path = if verify_signatures {
466            loader.extract_library_verified(
467                platform,
468                &extract_dir_path,
469                true,
470                public_key_override,
471            )?
472        } else {
473            loader.extract_library_for_current_platform(&extract_dir_path)?
474        };
475
476        debug!("Extracted library to: {}", lib_path.display());
477
478        // Load the extracted library
479        Self::load_with_config(lib_path, config, log_callback)
480    }
481
482    /// Load a plugin by name, searching standard library paths.
483    ///
484    /// Searches for the library in:
485    /// 1. Current directory
486    /// 2. `./target/release`
487    /// 3. `./target/debug`
488    /// 4. System library paths (LD_LIBRARY_PATH on Linux, etc.)
489    ///
490    /// # Arguments
491    ///
492    /// * `name` - Library name without prefix/suffix (e.g., "myplugin" finds "libmyplugin.so")
493    ///
494    /// # Example
495    ///
496    /// ```ignore
497    /// // Searches for libmyplugin.so (Linux), libmyplugin.dylib (macOS), myplugin.dll (Windows)
498    /// let plugin = NativePluginLoader::load_by_name("myplugin")?;
499    /// ```
500    pub fn load_by_name(name: &str) -> ConsumerResult<NativePlugin> {
501        Self::load_by_name_with_config(name, &PluginConfig::default(), None)
502    }
503
504    /// Load a plugin by name with custom configuration.
505    pub fn load_by_name_with_config(
506        name: &str,
507        config: &PluginConfig,
508        log_callback: Option<LogCallbackFn>,
509    ) -> ConsumerResult<NativePlugin> {
510        let lib_name = library_filename(name);
511
512        // Search paths
513        let search_paths = [
514            std::path::PathBuf::from("."),
515            std::path::PathBuf::from("./target/release"),
516            std::path::PathBuf::from("./target/debug"),
517        ];
518
519        for search_path in &search_paths {
520            let full_path = search_path.join(&lib_name);
521            if full_path.exists() {
522                debug!("Found library at: {}", full_path.display());
523                return Self::load_with_config(full_path, config, log_callback);
524            }
525        }
526
527        // Try loading directly (system library paths)
528        debug!("Attempting to load '{}' from system paths", lib_name);
529        Self::load_with_config(&lib_name, config, log_callback)
530    }
531}
532
/// Get the platform-specific library filename for a given name.
///
/// Built from the standard library's dynamic-library prefix and suffix for
/// the compilation target, which gives:
///
/// - Linux: `lib{name}.so`
/// - macOS: `lib{name}.dylib`
/// - Windows: `{name}.dll`
///
/// Other targets get their own native prefix and suffix instead of a
/// hard-coded `lib{name}.so`.
fn library_filename(name: &str) -> String {
    use std::env::consts::{DLL_PREFIX, DLL_SUFFIX};

    format!("{DLL_PREFIX}{name}{DLL_SUFFIX}")
}
556
#[cfg(test)]
mod tests {
    // Test names follow the `Subject___scenario___expectation` convention.
    #![allow(non_snake_case)]
    // `unwrap` is acceptable in tests: a failure should abort the test.
    #![allow(clippy::unwrap_used)]

    use super::*;
    use std::ffi::CString;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Loading a path that does not exist must surface a library-load error.
    #[test]
    fn NativePluginLoader___load___nonexistent_library___returns_error() {
        let outcome = NativePluginLoader::load("/nonexistent/library.so");

        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(ConsumerError::LibraryLoad(_))));
    }

    /// Loading a bundle that does not exist must surface a bundle error.
    #[test]
    fn NativePluginLoader___load_bundle___nonexistent_bundle___returns_error() {
        let outcome = NativePluginLoader::load_bundle("/nonexistent/bundle.rbp");

        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(ConsumerError::Bundle(_))));
    }

    /// The variant loader reports a missing bundle the same way.
    #[test]
    fn NativePluginLoader___load_bundle_variant___nonexistent_bundle___returns_error() {
        let default_config = PluginConfig::default();

        let outcome = NativePluginLoader::load_bundle_variant_with_config(
            "/nonexistent/bundle.rbp",
            "debug",
            &default_config,
            None,
        );

        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(ConsumerError::Bundle(_))));
    }

    /// With no callback registered the FFI trampoline must be a silent no-op.
    #[test]
    fn ffi_log_callback___no_callback_set___does_not_panic() {
        // Start from a clean slot on this test thread.
        set_log_callback(None);

        // Target is a null-terminated C string; message is bytes plus length.
        let target = CString::new("test").unwrap();
        let message: &[u8] = b"test message";

        // Must return without panicking.
        unsafe {
            ffi_log_callback(2, target.as_ptr(), message.as_ptr(), message.len());
        }
    }

    /// A registered callback receives the decoded level, target and message.
    #[test]
    fn ffi_log_callback___with_callback___invokes_callback() {
        let was_called = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&was_called);

        let callback: LogCallbackFn = Arc::new(move |level, target, message| {
            assert_eq!(level, LogLevel::Info);
            assert_eq!(target, "test");
            assert_eq!(message, "test message");
            flag.store(true, Ordering::SeqCst);
        });
        set_log_callback(Some(callback));

        let target = CString::new("test").unwrap();
        let message: &[u8] = b"test message";

        unsafe {
            ffi_log_callback(2, target.as_ptr(), message.as_ptr(), message.len());
        }

        assert!(was_called.load(Ordering::SeqCst));

        // Leave the thread-local slot empty again.
        set_log_callback(None);
    }

    /// Null target/message pointers are reported to the callback as "".
    #[test]
    fn ffi_log_callback___null_pointers___uses_empty_strings() {
        let was_called = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&was_called);

        let callback: LogCallbackFn = Arc::new(move |_level, target, message| {
            assert_eq!(target, "");
            assert_eq!(message, "");
            flag.store(true, Ordering::SeqCst);
        });
        set_log_callback(Some(callback));

        unsafe {
            ffi_log_callback(2, std::ptr::null(), std::ptr::null(), 0);
        }

        assert!(was_called.load(Ordering::SeqCst));

        // Leave the thread-local slot empty again.
        set_log_callback(None);
    }
}
667}