Skip to main content

rustbridge_consumer/
loader.rs

1//! Plugin loader for dynamically loading rustbridge plugins.
2
3use crate::error::{ConsumerError, ConsumerResult};
4use crate::ffi_bindings::{
5    FfiPluginHandle, LogCallback, PluginCallFn, PluginCallRawFn, PluginCreateFn,
6    PluginFreeBufferFn, PluginGetRejectedCountFn, PluginGetStateFn, PluginInitFn,
7    PluginSetLogLevelFn, PluginShutdownFn, RbResponseFreeFn,
8};
9use crate::plugin::NativePlugin;
10use libloading::Library;
11use rustbridge_bundle::BundleLoader;
12use rustbridge_core::{LogLevel, PluginConfig};
13use std::ffi::c_char;
14use std::path::Path;
15use std::sync::Arc;
16use std::sync::atomic::{AtomicU64, Ordering};
17use tracing::debug;
18
19/// Monotonic counter to ensure each bundle extraction gets a unique directory.
20/// This prevents SIGBUS when multiple threads extract to the same path concurrently
21/// (one thread truncates the .so file while another has it mmap'd).
22static EXTRACT_INSTANCE: AtomicU64 = AtomicU64::new(0);
23
24/// Rust-friendly log callback type.
25///
26/// This callback receives log messages from the plugin.
27pub type LogCallbackFn = Arc<dyn Fn(LogLevel, &str, &str) + Send + Sync>;
28
29// Global log callback storage.
30// Uses a RwLock so the FFI callback can read from any thread while
31// set_log_callback writes only during plugin load/unload.
32static LOG_CALLBACK: std::sync::RwLock<Option<LogCallbackFn>> = std::sync::RwLock::new(None);
33
34/// Set the global log callback.
35fn set_log_callback(callback: Option<LogCallbackFn>) {
36    if let Ok(mut guard) = LOG_CALLBACK.write() {
37        *guard = callback;
38    }
39}
40
41/// FFI-compatible log callback that forwards to the Rust callback.
42///
43/// # Safety
44/// - `target` must be a valid null-terminated C string or null
45/// - `message` must be valid for `message_len` bytes or null
46unsafe extern "C" fn ffi_log_callback(
47    level: u8,
48    target: *const c_char,
49    message: *const u8,
50    message_len: usize,
51) {
52    let callback = LOG_CALLBACK.read().ok().and_then(|guard| guard.clone());
53    if let Some(callback) = callback {
54        let log_level = LogLevel::from_u8(level);
55
56        // SAFETY: target is a valid null-terminated C string
57        let target_str = if target.is_null() {
58            ""
59        } else {
60            unsafe { std::ffi::CStr::from_ptr(target) }
61                .to_str()
62                .unwrap_or("")
63        };
64
65        // SAFETY: message is valid for message_len bytes (NOT null-terminated)
66        let message_str = if message.is_null() || message_len == 0 {
67            ""
68        } else {
69            let bytes = unsafe { std::slice::from_raw_parts(message, message_len) };
70            std::str::from_utf8(bytes).unwrap_or("")
71        };
72
73        callback(log_level, target_str, message_str);
74    }
75}
76
/// Loader for native plugins.
///
/// A unit (stateless) type. All functionality is exposed through associated
/// functions that load plugins from shared libraries or `.rbp` bundles.
#[derive(Debug, Clone, Copy, Default)]
pub struct NativePluginLoader;
81
82impl NativePluginLoader {
83    /// Load a plugin from a shared library path.
84    ///
85    /// Uses default configuration and no log callback.
86    ///
87    /// # Arguments
88    ///
89    /// * `path` - Path to the shared library (.so, .dylib, or .dll)
90    ///
91    /// # Example
92    ///
93    /// ```ignore
94    /// let plugin = NativePluginLoader::load("target/release/libmy_plugin.so")?;
95    /// ```
96    pub fn load<P: AsRef<Path>>(path: P) -> ConsumerResult<NativePlugin> {
97        Self::load_with_config(path, &PluginConfig::default(), None)
98    }
99
100    /// Load a plugin with custom configuration.
101    ///
102    /// # Arguments
103    ///
104    /// * `path` - Path to the shared library
105    /// * `config` - Plugin configuration
106    /// * `log_callback` - Optional callback to receive log messages
107    ///
108    /// # Example
109    ///
110    /// ```ignore
111    /// let config = PluginConfig::default();
112    /// let log_callback: LogCallbackFn = Arc::new(|level, target, msg| {
113    ///     println!("[{level}] {target}: {msg}");
114    /// });
115    ///
116    /// let plugin = NativePluginLoader::load_with_config(
117    ///     "target/release/libmy_plugin.so",
118    ///     &config,
119    ///     Some(log_callback),
120    /// )?;
121    /// ```
122    pub fn load_with_config<P: AsRef<Path>>(
123        path: P,
124        config: &PluginConfig,
125        log_callback: Option<LogCallbackFn>,
126    ) -> ConsumerResult<NativePlugin> {
127        let path = path.as_ref();
128        debug!("Loading plugin from: {}", path.display());
129
130        // Load the shared library
131        // SAFETY: We're loading a shared library which requires unsafe
132        let library = unsafe { Library::new(path) }?;
133
134        // Load required symbols
135        let plugin_create: PluginCreateFn = unsafe { *library.get(b"plugin_create\0")? };
136        let plugin_init: PluginInitFn = unsafe { *library.get(b"plugin_init\0")? };
137        let plugin_call: PluginCallFn = unsafe { *library.get(b"plugin_call\0")? };
138        let plugin_shutdown: PluginShutdownFn = unsafe { *library.get(b"plugin_shutdown\0")? };
139        let plugin_get_state: PluginGetStateFn = unsafe { *library.get(b"plugin_get_state\0")? };
140        let plugin_get_rejected_count: PluginGetRejectedCountFn =
141            unsafe { *library.get(b"plugin_get_rejected_count\0")? };
142        let plugin_set_log_level: PluginSetLogLevelFn =
143            unsafe { *library.get(b"plugin_set_log_level\0")? };
144        let plugin_free_buffer: PluginFreeBufferFn =
145            unsafe { *library.get(b"plugin_free_buffer\0")? };
146
147        // Load optional binary transport symbols
148        let plugin_call_raw: Option<PluginCallRawFn> =
149            unsafe { library.get(b"plugin_call_raw\0").ok().map(|s| *s) };
150        let rb_response_free: Option<RbResponseFreeFn> =
151            unsafe { library.get(b"rb_response_free\0").ok().map(|s| *s) };
152
153        // Set up log callback if provided.
154        // Only update the global when a real callback is given — loading a
155        // plugin without a callback must not clobber an existing one, since
156        // the FFI layer uses a single global callback for all instances.
157        if log_callback.is_some() {
158            set_log_callback(log_callback);
159        }
160        let ffi_callback: LogCallback = Some(ffi_log_callback);
161
162        // Create the plugin instance
163        // SAFETY: plugin_create returns a valid pointer or null
164        let plugin_ptr = unsafe { plugin_create() };
165        if plugin_ptr.is_null() {
166            return Err(ConsumerError::NullHandle);
167        }
168
169        // Serialize config to JSON
170        let config_json = serde_json::to_vec(config)?;
171
172        // Initialize the plugin
173        // SAFETY: plugin_ptr is valid, config_json is valid for its length
174        let handle: FfiPluginHandle = unsafe {
175            plugin_init(
176                plugin_ptr,
177                config_json.as_ptr(),
178                config_json.len(),
179                ffi_callback,
180            )
181        };
182
183        if handle.is_null() {
184            return Err(ConsumerError::NullHandle);
185        }
186
187        debug!("Plugin initialized with handle: {:?}", handle);
188
189        // SAFETY: All pointers are valid and came from the library
190        Ok(unsafe {
191            NativePlugin::new(
192                library,
193                handle,
194                plugin_call,
195                plugin_call_raw,
196                plugin_shutdown,
197                plugin_get_state,
198                plugin_get_rejected_count,
199                plugin_set_log_level,
200                plugin_free_buffer,
201                rb_response_free,
202            )
203        })
204    }
205
206    /// Load a plugin from a bundle file.
207    ///
208    /// Extracts the library for the current platform and loads it.
209    ///
210    /// # Arguments
211    ///
212    /// * `bundle_path` - Path to the .rbp bundle file
213    ///
214    /// # Example
215    ///
216    /// ```ignore
217    /// let plugin = NativePluginLoader::load_bundle("my-plugin-1.0.0.rbp")?;
218    /// ```
219    pub fn load_bundle<P: AsRef<Path>>(bundle_path: P) -> ConsumerResult<NativePlugin> {
220        Self::load_bundle_with_config(bundle_path, &PluginConfig::default(), None)
221    }
222
223    /// Load a plugin from a bundle with custom configuration.
224    ///
225    /// # Arguments
226    ///
227    /// * `bundle_path` - Path to the .rbp bundle file
228    /// * `config` - Plugin configuration
229    /// * `log_callback` - Optional callback to receive log messages
230    ///
231    /// # Example
232    ///
233    /// ```ignore
234    /// let config = PluginConfig::default();
235    /// let plugin = NativePluginLoader::load_bundle_with_config(
236    ///     "my-plugin-1.0.0.rbp",
237    ///     &config,
238    ///     None,
239    /// )?;
240    /// ```
241    pub fn load_bundle_with_config<P: AsRef<Path>>(
242        bundle_path: P,
243        config: &PluginConfig,
244        log_callback: Option<LogCallbackFn>,
245    ) -> ConsumerResult<NativePlugin> {
246        let bundle_path = bundle_path.as_ref();
247        debug!("Loading bundle from: {}", bundle_path.display());
248
249        // Open and validate the bundle
250        let mut loader = BundleLoader::open(bundle_path)?;
251
252        // Check platform support
253        if !loader.supports_current_platform() {
254            return Err(ConsumerError::Bundle(
255                rustbridge_bundle::BundleError::UnsupportedPlatform(
256                    "Current platform not supported by bundle".to_string(),
257                ),
258            ));
259        }
260
261        // Each load gets a unique extraction directory to prevent SIGBUS from
262        // concurrent threads overwriting a file that another thread has mmap'd.
263        let instance_id = EXTRACT_INSTANCE.fetch_add(1, Ordering::Relaxed);
264        let extract_dir = bundle_path
265            .parent()
266            .unwrap_or(Path::new("."))
267            .join(".rustbridge-cache")
268            .join(loader.manifest().plugin.name.as_str())
269            .join(loader.manifest().plugin.version.as_str())
270            .join(instance_id.to_string());
271
272        // Extract the library for the current platform
273        let lib_path = loader.extract_library_for_current_platform(&extract_dir)?;
274
275        debug!("Extracted library to: {}", lib_path.display());
276
277        // Load the extracted library
278        Self::load_with_config(lib_path, config, log_callback)
279    }
280
281    /// Load a specific variant from a bundle with custom configuration.
282    ///
283    /// Unlike `load_bundle_with_config` which always extracts the default (release) variant,
284    /// this method extracts the named variant (e.g., "debug", "release").
285    ///
286    /// # Arguments
287    ///
288    /// * `bundle_path` - Path to the .rbp bundle file
289    /// * `variant` - Variant name (e.g., "release", "debug")
290    /// * `config` - Plugin configuration
291    /// * `log_callback` - Optional callback to receive log messages
292    ///
293    /// # Example
294    ///
295    /// ```ignore
296    /// let config = PluginConfig::default();
297    /// let plugin = NativePluginLoader::load_bundle_variant_with_config(
298    ///     "my-plugin-1.0.0.rbp",
299    ///     "debug",
300    ///     &config,
301    ///     None,
302    /// )?;
303    /// ```
304    pub fn load_bundle_variant_with_config<P: AsRef<Path>>(
305        bundle_path: P,
306        variant: &str,
307        config: &PluginConfig,
308        log_callback: Option<LogCallbackFn>,
309    ) -> ConsumerResult<NativePlugin> {
310        let bundle_path = bundle_path.as_ref();
311        debug!(
312            "Loading bundle variant '{}' from: {}",
313            variant,
314            bundle_path.display()
315        );
316
317        // Open and validate the bundle
318        let mut loader = BundleLoader::open(bundle_path)?;
319
320        // Check platform support
321        let platform = rustbridge_bundle::Platform::current().ok_or_else(|| {
322            ConsumerError::Bundle(rustbridge_bundle::BundleError::UnsupportedPlatform(
323                "Current platform is not supported".to_string(),
324            ))
325        })?;
326
327        if !loader.supports_current_platform() {
328            return Err(ConsumerError::Bundle(
329                rustbridge_bundle::BundleError::UnsupportedPlatform(
330                    "Current platform not supported by bundle".to_string(),
331                ),
332            ));
333        }
334
335        // Each load gets a unique extraction directory to prevent SIGBUS from
336        // concurrent threads overwriting a file that another thread has mmap'd.
337        let instance_id = EXTRACT_INSTANCE.fetch_add(1, Ordering::Relaxed);
338        let extract_dir = bundle_path
339            .parent()
340            .unwrap_or(Path::new("."))
341            .join(".rustbridge-cache")
342            .join(loader.manifest().plugin.name.as_str())
343            .join(loader.manifest().plugin.version.as_str())
344            .join(format!("{variant}-{instance_id}"));
345
346        // Extract the specified variant
347        let lib_path = loader.extract_library_variant(platform, variant, &extract_dir)?;
348
349        debug!("Extracted variant library to: {}", lib_path.display());
350
351        // Load the extracted library
352        Self::load_with_config(lib_path, config, log_callback)
353    }
354
355    /// Load a plugin from a bundle to a specific extraction directory.
356    ///
357    /// This is useful when you want to control where the library is extracted.
358    ///
359    /// # Arguments
360    ///
361    /// * `bundle_path` - Path to the .rbp bundle file
362    /// * `extract_dir` - Directory to extract the library to
363    /// * `config` - Plugin configuration
364    /// * `log_callback` - Optional callback to receive log messages
365    pub fn load_bundle_to_dir<P: AsRef<Path>, Q: AsRef<Path>>(
366        bundle_path: P,
367        extract_dir: Q,
368        config: &PluginConfig,
369        log_callback: Option<LogCallbackFn>,
370    ) -> ConsumerResult<NativePlugin> {
371        Self::load_bundle_verified(
372            bundle_path,
373            Some(extract_dir),
374            config,
375            log_callback,
376            false,
377            None,
378        )
379    }
380
381    /// Load a plugin from a bundle with signature verification.
382    ///
383    /// # Arguments
384    ///
385    /// * `bundle_path` - Path to the .rbp bundle file
386    /// * `config` - Plugin configuration
387    /// * `log_callback` - Optional callback to receive log messages
388    /// * `verify_signatures` - Whether to verify minisign signatures
389    /// * `public_key_override` - Optional public key to use instead of manifest's key
390    ///
391    /// # Example
392    ///
393    /// ```ignore
394    /// // Load with signature verification (recommended for production)
395    /// let plugin = NativePluginLoader::load_bundle_with_verification(
396    ///     "my-plugin-1.0.0.rbp",
397    ///     &PluginConfig::default(),
398    ///     None,
399    ///     true,  // verify signatures
400    ///     None,  // use manifest's public key
401    /// )?;
402    /// ```
403    pub fn load_bundle_with_verification<P: AsRef<Path>>(
404        bundle_path: P,
405        config: &PluginConfig,
406        log_callback: Option<LogCallbackFn>,
407        verify_signatures: bool,
408        public_key_override: Option<&str>,
409    ) -> ConsumerResult<NativePlugin> {
410        Self::load_bundle_verified(
411            bundle_path,
412            None::<&Path>,
413            config,
414            log_callback,
415            verify_signatures,
416            public_key_override,
417        )
418    }
419
420    /// Internal method to load a bundle with all options.
421    fn load_bundle_verified<P: AsRef<Path>, Q: AsRef<Path>>(
422        bundle_path: P,
423        extract_dir: Option<Q>,
424        config: &PluginConfig,
425        log_callback: Option<LogCallbackFn>,
426        verify_signatures: bool,
427        public_key_override: Option<&str>,
428    ) -> ConsumerResult<NativePlugin> {
429        let bundle_path = bundle_path.as_ref();
430        debug!("Loading bundle from: {}", bundle_path.display());
431
432        // Open and validate the bundle
433        let mut loader = BundleLoader::open(bundle_path)?;
434
435        // Check platform support
436        let platform = rustbridge_bundle::Platform::current().ok_or_else(|| {
437            ConsumerError::Bundle(rustbridge_bundle::BundleError::UnsupportedPlatform(
438                "Current platform is not supported".to_string(),
439            ))
440        })?;
441
442        if !loader.supports_current_platform() {
443            return Err(ConsumerError::Bundle(
444                rustbridge_bundle::BundleError::UnsupportedPlatform(
445                    "Current platform not supported by bundle".to_string(),
446                ),
447            ));
448        }
449
450        // Determine extraction directory
451        // When no explicit dir is provided, each load gets a unique directory to
452        // prevent SIGBUS from concurrent threads overwriting a mmap'd file.
453        let extract_dir_path: std::path::PathBuf = match extract_dir {
454            Some(dir) => dir.as_ref().to_path_buf(),
455            None => {
456                let instance_id = EXTRACT_INSTANCE.fetch_add(1, Ordering::Relaxed);
457                bundle_path
458                    .parent()
459                    .unwrap_or(Path::new("."))
460                    .join(".rustbridge-cache")
461                    .join(loader.manifest().plugin.name.as_str())
462                    .join(loader.manifest().plugin.version.as_str())
463                    .join(instance_id.to_string())
464            }
465        };
466
467        // Extract with or without verification
468        let lib_path = if verify_signatures {
469            loader.extract_library_verified(
470                platform,
471                &extract_dir_path,
472                true,
473                public_key_override,
474            )?
475        } else {
476            loader.extract_library_for_current_platform(&extract_dir_path)?
477        };
478
479        debug!("Extracted library to: {}", lib_path.display());
480
481        // Load the extracted library
482        Self::load_with_config(lib_path, config, log_callback)
483    }
484
485    /// Load a plugin by name, searching standard library paths.
486    ///
487    /// Searches for the library in:
488    /// 1. Current directory
489    /// 2. `./target/release`
490    /// 3. `./target/debug`
491    /// 4. System library paths (LD_LIBRARY_PATH on Linux, etc.)
492    ///
493    /// # Arguments
494    ///
495    /// * `name` - Library name without prefix/suffix (e.g., "myplugin" finds "libmyplugin.so")
496    ///
497    /// # Example
498    ///
499    /// ```ignore
500    /// // Searches for libmyplugin.so (Linux), libmyplugin.dylib (macOS), myplugin.dll (Windows)
501    /// let plugin = NativePluginLoader::load_by_name("myplugin")?;
502    /// ```
503    pub fn load_by_name(name: &str) -> ConsumerResult<NativePlugin> {
504        Self::load_by_name_with_config(name, &PluginConfig::default(), None)
505    }
506
507    /// Load a plugin by name with custom configuration.
508    pub fn load_by_name_with_config(
509        name: &str,
510        config: &PluginConfig,
511        log_callback: Option<LogCallbackFn>,
512    ) -> ConsumerResult<NativePlugin> {
513        let lib_name = library_filename(name);
514
515        // Search paths relative to CWD
516        let mut search_paths = vec![
517            std::path::PathBuf::from("."),
518            std::path::PathBuf::from("./target/release"),
519            std::path::PathBuf::from("./target/debug"),
520        ];
521
522        // During `cargo test`, CARGO_MANIFEST_DIR points to the crate root.
523        // Walk up to find the workspace target directory so load_by_name works
524        // regardless of which crate's tests are running.
525        if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
526            let manifest_path = std::path::PathBuf::from(manifest_dir);
527            for ancestor in manifest_path.ancestors().skip(1) {
528                let release = ancestor.join("target").join("release");
529                if release.is_dir() {
530                    search_paths.push(release);
531                    search_paths.push(ancestor.join("target").join("debug"));
532                    break;
533                }
534            }
535        }
536
537        for search_path in &search_paths {
538            let full_path = search_path.join(&lib_name);
539            if full_path.exists() {
540                debug!("Found library at: {}", full_path.display());
541                return Self::load_with_config(full_path, config, log_callback);
542            }
543        }
544
545        // Try loading directly (system library paths)
546        debug!("Attempting to load '{}' from system paths", lib_name);
547        Self::load_with_config(&lib_name, config, log_callback)
548    }
549}
550
/// Get the platform-specific library filename for a given name.
///
/// Uses the target's own dynamic-library naming convention
/// (`std::env::consts::DLL_PREFIX` / `DLL_SUFFIX`):
///
/// - Linux: `lib{name}.so`
/// - macOS: `lib{name}.dylib`
/// - Windows: `{name}.dll`
///
/// Other targets get their native convention instead of a hard-coded
/// `lib{name}.so` fallback.
fn library_filename(name: &str) -> String {
    format!(
        "{}{}{}",
        std::env::consts::DLL_PREFIX,
        name,
        std::env::consts::DLL_SUFFIX
    )
}
574
#[cfg(test)]
mod tests {
    #![allow(non_snake_case)]
    #![allow(clippy::unwrap_used)]

    use super::*;
    use std::ffi::CString;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes every test that reads or writes the process-global
    /// `LOG_CALLBACK`.
    ///
    /// The default test harness runs tests on parallel threads. Without this
    /// lock, one test can clear or replace another test's callback between
    /// `set_log_callback` and `ffi_log_callback`, which makes the "callback
    /// was invoked" assertions flaky. It also lets another test's callback,
    /// with its own `assert_eq!`s, panic inside `ffi_log_callback`.
    static LOG_CALLBACK_TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Acquire `LOG_CALLBACK_TEST_LOCK`, ignoring poisoning left by a failed
    /// test so one failure does not cascade into the others.
    fn log_callback_test_guard() -> MutexGuard<'static, ()> {
        LOG_CALLBACK_TEST_LOCK
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
    }

    #[test]
    fn NativePluginLoader___load___nonexistent_library___returns_error() {
        let result = NativePluginLoader::load("/nonexistent/library.so");

        assert!(result.is_err());
        let err = result.err().unwrap();
        assert!(matches!(err, ConsumerError::LibraryLoad(_)));
    }

    #[test]
    fn NativePluginLoader___load_bundle___nonexistent_bundle___returns_error() {
        let result = NativePluginLoader::load_bundle("/nonexistent/bundle.rbp");

        assert!(result.is_err());
        let err = result.err().unwrap();
        assert!(matches!(err, ConsumerError::Bundle(_)));
    }

    #[test]
    fn NativePluginLoader___load_bundle_variant___nonexistent_bundle___returns_error() {
        let result = NativePluginLoader::load_bundle_variant_with_config(
            "/nonexistent/bundle.rbp",
            "debug",
            &PluginConfig::default(),
            None,
        );

        assert!(result.is_err());
        let err = result.err().unwrap();
        assert!(matches!(err, ConsumerError::Bundle(_)));
    }

    #[test]
    fn ffi_log_callback___no_callback_set___does_not_panic() {
        let _guard = log_callback_test_guard();

        // Clear any existing callback
        set_log_callback(None);

        // Create target as null-terminated C string, message as bytes with length
        let target = CString::new("test").unwrap();
        let message = b"test message";

        // This should not panic
        unsafe {
            ffi_log_callback(2, target.as_ptr(), message.as_ptr(), message.len());
        }
    }

    #[test]
    fn ffi_log_callback___with_callback___invokes_callback() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};

        let _guard = log_callback_test_guard();

        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let callback: LogCallbackFn = Arc::new(move |level, target, message| {
            assert_eq!(level, LogLevel::Info);
            assert_eq!(target, "test");
            assert_eq!(message, "test message");
            called_clone.store(true, Ordering::SeqCst);
        });

        set_log_callback(Some(callback));

        let target = CString::new("test").unwrap();
        let message = b"test message";

        unsafe {
            ffi_log_callback(2, target.as_ptr(), message.as_ptr(), message.len());
        }

        assert!(called.load(Ordering::SeqCst));

        // Clean up
        set_log_callback(None);
    }

    #[test]
    fn ffi_log_callback___null_pointers___uses_empty_strings() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};

        let _guard = log_callback_test_guard();

        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let callback: LogCallbackFn = Arc::new(move |_level, target, message| {
            assert_eq!(target, "");
            assert_eq!(message, "");
            called_clone.store(true, Ordering::SeqCst);
        });

        set_log_callback(Some(callback));

        unsafe {
            ffi_log_callback(2, std::ptr::null(), std::ptr::null(), 0);
        }

        assert!(called.load(Ordering::SeqCst));

        // Clean up
        set_log_callback(None);
    }
}