use std::ffi::c_void;
use std::path::Path;
use std::sync::Arc;
use fidius_core::descriptor::*;
use libloading::Library;
use crate::error::LoadError;
use crate::types::PluginInfo;
/// A dynamic library that was opened by `load_library` and whose plugin
/// registry passed the magic, registry-version and per-descriptor checks.
pub struct LoadedLibrary {
/// Shared handle to the opened library. `load_library` wraps it in an `Arc`
/// and every entry in `plugins` holds its own clone of this handle.
pub library: Arc<Library>,
/// One validated plugin per descriptor in the library's registry, in the
/// order the registry lists them.
pub plugins: Vec<LoadedPlugin>,
}
/// A single validated plugin taken from a loaded library's registry.
///
/// `vtable`, `free_buffer` and `descriptor` are copied out of a descriptor that
/// lives inside the dynamic library. The `library` field holds a strong
/// reference to that library, so it is not dropped (and unloaded) while this
/// value exists.
pub struct LoadedPlugin {
/// Metadata copied out of the descriptor: names, interface hash/version,
/// capabilities, buffer strategy and runtime kind.
pub info: PluginInfo,
/// The descriptor's `vtable` pointer, copied verbatim. Its layout is
/// interface-specific and is not checked in this file.
pub vtable: *const c_void,
/// The descriptor's `free_buffer` callback, if the plugin provides one.
pub free_buffer: Option<unsafe extern "C" fn(*mut u8, usize)>,
/// Method count reported by the descriptor. NOTE(review): presumably the
/// number of entries behind `vtable`; this is not verified here.
pub method_count: u32,
/// Pointer to the original descriptor inside the library.
pub descriptor: *const PluginDescriptor,
/// Keeps the library loaded while the pointers above are in use.
pub library: Arc<Library>,
}
// SAFETY: the raw pointer fields make `LoadedPlugin` !Send/!Sync by default.
// This file never dereferences or writes through them. They point into the
// dynamic library, which the `library` Arc keeps loaded. Whether it is sound
// to call through `vtable` / `free_buffer` concurrently from several threads
// depends on the plugin implementation.
// NOTE(review): confirm that the fidius ABI requires plugins to be thread-safe.
unsafe impl Send for LoadedPlugin {}
unsafe impl Sync for LoadedPlugin {}
impl std::fmt::Debug for LoadedPlugin {
    /// Formats the plugin's metadata and raw FFI pointers for diagnostics.
    ///
    /// The `library` handle is not printed. `finish_non_exhaustive` renders a
    /// trailing `..` so the output does not pretend to show every field.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LoadedPlugin")
            .field("info", &self.info)
            .field("vtable", &self.vtable)
            .field("method_count", &self.method_count)
            .field("free_buffer", &self.free_buffer)
            .field("descriptor", &self.descriptor)
            .finish_non_exhaustive()
    }
}
/// Opens the dynamic library at `path` and validates its fidius plugin registry.
///
/// Steps, in order:
/// 1. `crate::arch::check_architecture` is run on `path`.
/// 2. The library is opened with `libloading`.
/// 3. The exported `fidius_get_registry` function is called to obtain the registry.
/// 4. The registry's magic and `registry_version` are checked.
/// 5. Every descriptor in the registry is validated with `validate_descriptor`.
///
/// # Errors
///
/// - Any error returned by `check_architecture`.
/// - `LoadError::LibraryNotFound` when the loader's error message mentions a
///   missing file; `LoadError::LibLoading` for any other open failure.
/// - `LoadError::SymbolNotFound` when `fidius_get_registry` is not exported.
/// - `LoadError::InvalidMagic` when the registry pointer is null, the magic
///   does not match, or the descriptor table (or an entry in it) is null.
/// - `LoadError::IncompatibleRegistryVersion` on a registry version mismatch.
/// - The first error produced while validating an individual descriptor.
pub fn load_library(path: &Path) -> Result<LoadedLibrary, LoadError> {
    let path_str = path.display().to_string();
    #[cfg(feature = "tracing")]
    tracing::debug!(path = %path_str, "loading library");
    crate::arch::check_architecture(path)?;

    // SAFETY: opening a library runs its initialisers; the caller is trusted
    // to only point this at libraries it is prepared to execute.
    let library = unsafe { Library::new(path) }.map_err(|e| {
        // Format the error once. Loader messages are platform-specific, so
        // "not found" detection is a best-effort substring match.
        let msg = e.to_string();
        if msg.contains("No such file") || msg.contains("not found") {
            LoadError::LibraryNotFound {
                path: path_str.clone(),
            }
        } else {
            LoadError::LibLoading(e)
        }
    })?;

    // SAFETY: the declared signature must match the plugin's exported
    // `fidius_get_registry`; this is the fidius ABI contract.
    let get_registry: libloading::Symbol<unsafe extern "C" fn() -> *const PluginRegistry> =
        unsafe { library.get(b"fidius_get_registry") }.map_err(|_| LoadError::SymbolNotFound {
            path: path_str.clone(),
        })?;

    // SAFETY: calling the exported registry accessor; it takes no arguments.
    let registry_ptr = unsafe { get_registry() };
    // Dereferencing a null registry would be undefined behaviour; treat it as
    // "not a valid fidius registry", the same as a bad magic value.
    if registry_ptr.is_null() {
        return Err(LoadError::InvalidMagic);
    }
    // SAFETY: non-null was checked above. The registry lives inside `library`,
    // which stays loaded for the rest of this function and afterwards through
    // the `Arc` held by `LoadedLibrary` and every `LoadedPlugin`.
    let registry = unsafe { &*registry_ptr };

    if registry.magic != FIDIUS_MAGIC {
        return Err(LoadError::InvalidMagic);
    }
    if registry.registry_version != REGISTRY_VERSION {
        return Err(LoadError::IncompatibleRegistryVersion {
            got: registry.registry_version,
            expected: REGISTRY_VERSION,
        });
    }

    let count = registry.plugin_count as usize;
    // A non-empty registry with a null descriptor table cannot be indexed.
    if count > 0 && registry.descriptors.is_null() {
        return Err(LoadError::InvalidMagic);
    }

    let library = Arc::new(library);
    let mut plugins = Vec::with_capacity(count);
    for i in 0..count {
        // SAFETY: `descriptors` is non-null (checked above). `i < plugin_count`.
        // NOTE(review): the table length reported by the plugin is trusted.
        let desc_ptr = unsafe { *registry.descriptors.add(i) };
        if desc_ptr.is_null() {
            return Err(LoadError::InvalidMagic);
        }
        // SAFETY: non-null was checked above; the descriptor lives inside `library`.
        let desc = unsafe { &*desc_ptr };
        plugins.push(validate_descriptor(desc, &library)?);
    }

    Ok(LoadedLibrary { library, plugins })
}
/// Checks one registry descriptor and converts it into a `LoadedPlugin`.
///
/// The ABI version is checked first. Only then are the name strings copied
/// out and the buffer strategy decoded. The returned plugin keeps its own
/// clone of `library` so the descriptor's memory stays loaded.
///
/// # Errors
///
/// - `LoadError::IncompatibleAbiVersion` when `abi_version != ABI_VERSION`.
/// - `LoadError::UnknownBufferStrategy` when the descriptor's buffer strategy
///   value cannot be decoded.
fn validate_descriptor(
    desc: &PluginDescriptor,
    library: &Arc<Library>,
) -> Result<LoadedPlugin, LoadError> {
    let abi = desc.abi_version;
    if abi != ABI_VERSION {
        return Err(LoadError::IncompatibleAbiVersion {
            got: abi,
            expected: ABI_VERSION,
        });
    }

    // SAFETY: the ABI version matches, so the name fields have the layout
    // `interface_name_str` / `plugin_name_str` expect.
    let interface_name = unsafe { desc.interface_name_str() }.to_string();
    let name = unsafe { desc.plugin_name_str() }.to_string();
    let buffer_strategy = desc
        .buffer_strategy_kind()
        .map_err(|value| LoadError::UnknownBufferStrategy { value })?;

    let info = PluginInfo {
        name,
        interface_name,
        interface_hash: desc.interface_hash,
        interface_version: desc.interface_version,
        capabilities: desc.capabilities,
        buffer_strategy,
        runtime: crate::types::PluginRuntimeKind::Cdylib,
    };

    let descriptor: *const PluginDescriptor = desc;
    Ok(LoadedPlugin {
        info,
        vtable: desc.vtable,
        free_buffer: desc.free_buffer,
        method_count: desc.method_count,
        descriptor,
        library: Arc::clone(library),
    })
}
/// Checks a loaded plugin against what the host's interface definition expects.
///
/// Each expectation is optional; passing `None` skips that check. The
/// interface hash is checked before the buffer strategy.
///
/// # Errors
///
/// - `LoadError::InterfaceHashMismatch` when `expected_hash` is `Some` and
///   differs from the plugin's interface hash.
/// - `LoadError::BufferStrategyMismatch` when `expected_strategy` is `Some`
///   and differs from the plugin's buffer strategy.
pub fn validate_against_interface(
    plugin: &LoadedPlugin,
    expected_hash: Option<u64>,
    expected_strategy: Option<BufferStrategyKind>,
) -> Result<(), LoadError> {
    let actual_hash = plugin.info.interface_hash;
    if let Some(wanted) = expected_hash.filter(|&h| actual_hash != h) {
        return Err(LoadError::InterfaceHashMismatch {
            got: actual_hash,
            expected: wanted,
        });
    }

    let actual_strategy = plugin.info.buffer_strategy;
    match expected_strategy {
        Some(wanted) if actual_strategy != wanted => Err(LoadError::BufferStrategyMismatch {
            got: actual_strategy,
            expected: wanted,
        }),
        _ => Ok(()),
    }
}