use std::ffi::c_void;
use std::sync::Arc;
use libloading::Library;
use serde::de::DeserializeOwned;
use serde::Serialize;
use fidius_core::descriptor::{BufferStrategyKind, PluginDescriptor};
use fidius_core::status::*;
use fidius_core::wire;
use fidius_core::PluginError;
use crate::arena::{acquire_arena, grow_arena, release_arena, DEFAULT_ARENA_CAPACITY};
use crate::error::{CallError, LoadError};
use crate::types::PluginInfo;
/// Vtable entry ABI for the plugin-allocated buffer strategy.
///
/// The host passes the serialized input; the plugin stores a pointer to an output
/// buffer and its length through the two out-parameters and returns a status code.
type FfiFn = unsafe extern "C" fn(
    input_ptr: *const u8,
    input_len: u32,
    out_ptr: *mut *mut u8,
    out_len: *mut u32,
) -> i32;

/// Vtable entry ABI for the arena buffer strategy.
///
/// The host passes the serialized input plus a host-owned arena and its length; the
/// plugin writes its output inside the arena and reports where it lives through the
/// offset/length out-parameters, returning a status code.
type ArenaFn = unsafe extern "C" fn(
    input_ptr: *const u8,
    input_len: u32,
    arena_ptr: *mut u8,
    arena_len: u32,
    out_offset: *mut u32,
    out_len: *mut u32,
) -> i32;
/// Callable handle to a single plugin, built either from a dynamically loaded
/// library (`from_loaded` / `new`) or from a descriptor registered in the current
/// process (`from_descriptor`).
///
/// The handle stores raw pointers into the plugin's descriptor and vtable and
/// dispatches method calls through them.
// NOTE: field declaration order is also drop order; `_library` is declared first.
pub struct PluginHandle {
/// Shared reference to the loaded library, retained so the code and data the raw
/// pointers below refer to stay mapped while the handle exists. `None` for handles
/// built by `from_descriptor`.
_library: Option<Arc<Library>>,
/// Start of the method function-pointer table. Entries are indexed by method index
/// and reinterpreted as `FfiFn` or `ArenaFn` depending on the buffer strategy.
vtable: *const c_void,
/// Descriptor the handle was created from; dereferenced for metadata lookups.
descriptor: *const PluginDescriptor,
/// Plugin-provided deallocator for output buffers the plugin allocated; when
/// `None`, such buffers are not handed back.
free_buffer: Option<unsafe extern "C" fn(*mut u8, usize)>,
/// Capability bit mask consulted by `has_capability`.
capabilities: u64,
/// Number of entries in `vtable`; every method index must be below this.
method_count: u32,
/// Owned copy of the plugin's name, interface and buffer-strategy information.
info: PluginInfo,
}
// SAFETY: NOTE(review) — these impls assert that the plugin's vtable functions,
// descriptor data and `free_buffer` may be used from any thread, including
// concurrently. The raw-pointer fields are never written after construction (no
// method here takes `&mut self`), but the thread-safety of the plugin code itself
// is not checked here — confirm it is guaranteed by the plugin ABI contract.
unsafe impl Send for PluginHandle {}
unsafe impl Sync for PluginHandle {}
impl PluginHandle {
/// Assembles a handle from already-resolved pieces of a dynamically loaded plugin.
///
/// `library` is retained inside the handle so that the code and data behind
/// `vtable`, `descriptor` and `free_buffer` remain mapped while the handle lives.
/// `capabilities` is stored as given; it is not cross-checked against `info`.
// NOTE(review): the lint allowance suggests this crate-internal constructor may
// currently have no callers; it is kept as an alternative construction path.
#[allow(dead_code)]
pub(crate) fn new(
    library: Arc<Library>,
    vtable: *const c_void,
    descriptor: *const PluginDescriptor,
    free_buffer: Option<unsafe extern "C" fn(*mut u8, usize)>,
    capabilities: u64,
    method_count: u32,
    info: PluginInfo,
) -> Self {
    let _library = Some(library);
    Self {
        info,
        method_count,
        capabilities,
        free_buffer,
        descriptor,
        vtable,
        _library,
    }
}
pub fn from_loaded(plugin: crate::loader::LoadedPlugin) -> Self {
Self {
_library: Some(plugin.library),
vtable: plugin.vtable,
descriptor: plugin.descriptor,
free_buffer: plugin.free_buffer,
capabilities: plugin.info.capabilities,
method_count: plugin.method_count,
info: plugin.info,
}
}
/// Builds a handle directly from a descriptor with `'static` lifetime, such as one
/// returned by `find_in_process_descriptor`.
///
/// No library is retained (`_library` is `None`). The vtable, buffer deallocator,
/// capability mask and method count are taken from `desc`, and the plugin and
/// interface names are copied into owned strings.
///
/// # Errors
///
/// Returns [`LoadError::UnknownBufferStrategy`] when the descriptor's buffer
/// strategy value is not recognised by `buffer_strategy_kind`.
pub fn from_descriptor(desc: &'static PluginDescriptor) -> Result<Self, LoadError> {
    // SAFETY: NOTE(review) — relies on the registering plugin having stored valid
    // name strings in this `'static` descriptor, as the accessors require.
    let name = unsafe { desc.plugin_name_str() }.to_string();
    let interface_name = unsafe { desc.interface_name_str() }.to_string();
    let buffer_strategy = desc
        .buffer_strategy_kind()
        .map_err(|value| LoadError::UnknownBufferStrategy { value })?;

    let info = PluginInfo {
        name,
        interface_name,
        interface_hash: desc.interface_hash,
        interface_version: desc.interface_version,
        capabilities: desc.capabilities,
        buffer_strategy,
        runtime: crate::types::PluginRuntimeKind::Cdylib,
    };

    Ok(Self {
        _library: None,
        vtable: desc.vtable,
        descriptor: desc as *const _,
        free_buffer: desc.free_buffer,
        capabilities: desc.capabilities,
        method_count: desc.method_count,
        info,
    })
}
/// Looks up a descriptor registered in the current process by plugin name.
///
/// Registry entries are scanned in order, and the first descriptor whose plugin
/// name equals `plugin_name` is returned.
///
/// # Errors
///
/// Returns [`LoadError::PluginNotFound`] when no registered descriptor matches.
pub fn find_in_process_descriptor(
    plugin_name: &str,
) -> Result<&'static PluginDescriptor, LoadError> {
    let registry = fidius_core::registry::get_registry();
    let count = registry.plugin_count as usize;
    (0..count)
        .map(|slot| {
            // SAFETY: NOTE(review) — assumes the registry's `descriptors` array holds
            // `plugin_count` valid pointers to descriptors that live for the program.
            let entry = unsafe { *registry.descriptors.add(slot) };
            unsafe { &*entry }
        })
        .find(|candidate| unsafe { candidate.plugin_name_str() } == plugin_name)
        .ok_or_else(|| LoadError::PluginNotFound {
            name: plugin_name.to_string(),
        })
}
/// Serializes `input` with the wire format, invokes vtable entry `index`, and
/// deserializes the plugin's reply into `O`.
///
/// The buffer strategy recorded in this handle's `info` selects the call path:
/// plugin-allocated output buffers or a host-provided arena.
///
/// # Errors
///
/// Returns [`CallError::InvalidMethodIndex`] when `index` is not below the method
/// count (checked before any serialization) and [`CallError::Serialization`] when
/// `input` cannot be encoded. Otherwise it returns whatever the selected call path
/// reports.
pub fn call_method<I: Serialize, O: DeserializeOwned>(
    &self,
    index: usize,
    input: &I,
) -> Result<O, CallError> {
    let count = self.method_count;
    if index >= count as usize {
        return Err(CallError::InvalidMethodIndex { index, count });
    }
    let encoded =
        wire::serialize(input).map_err(|err| CallError::Serialization(err.to_string()))?;
    match self.info.buffer_strategy {
        BufferStrategyKind::Arena => self.call_arena(index, &encoded),
        BufferStrategyKind::PluginAllocated => self.call_plugin_allocated(index, &encoded),
    }
}
/// Invokes vtable entry `index` with already-encoded `input` bytes and returns the
/// plugin's output bytes without decoding them.
///
/// # Errors
///
/// Returns [`CallError::InvalidMethodIndex`] when `index` is not below the method
/// count. Otherwise it propagates the error of the buffer-strategy-specific call path.
pub fn call_method_raw(&self, index: usize, input: &[u8]) -> Result<Vec<u8>, CallError> {
    let count = self.method_count;
    if index < count as usize {
        match self.info.buffer_strategy {
            BufferStrategyKind::Arena => self.call_arena_raw(index, input),
            BufferStrategyKind::PluginAllocated => self.call_plugin_allocated_raw(index, input),
        }
    } else {
        Err(CallError::InvalidMethodIndex { index, count })
    }
}
/// Calls vtable entry `index` under the plugin-allocated buffer strategy and
/// deserializes the plugin's output into `O`.
///
/// The plugin reports its result through `out_ptr`/`out_len`. This function may read
/// such a buffer, and when the handle has a `free_buffer` callback, the buffer is
/// handed back to it once decoding has finished, whether or not decoding succeeded.
///
/// Callers must already have checked `index < self.method_count`.
///
/// # Errors
///
/// * [`CallError::Serialization`] — `input_bytes` is longer than `u32::MAX`, the
///   plugin reported a serialization failure, or it returned success with a null
///   output buffer.
/// * [`CallError::BufferTooSmall`] / [`CallError::UnknownStatus`] — the matching
///   status codes.
/// * [`CallError::Plugin`] / [`CallError::Panic`] — the plugin reported an error or
///   a panic.
/// * [`CallError::Deserialization`] — the output or the error payload could not be
///   decoded.
fn call_plugin_allocated<O: DeserializeOwned>(
    &self,
    index: usize,
    input_bytes: &[u8],
) -> Result<O, CallError> {
    // The FFI signature carries the input length as a `u32`; reject inputs an `as`
    // cast would silently truncate instead of handing the plugin a partial payload.
    if input_bytes.len() > u32::MAX as usize {
        return Err(CallError::Serialization(
            "input exceeds u32::MAX bytes and cannot be passed over FFI".into(),
        ));
    }
    let input_len = input_bytes.len() as u32;

    // SAFETY: callers bounds-check `index` against `method_count`; the vtable is
    // assumed to hold `method_count` entries with the `FfiFn` ABI for this strategy.
    let fn_ptr = unsafe { *(self.vtable as *const FfiFn).add(index) };
    let mut out_ptr: *mut u8 = std::ptr::null_mut();
    let mut out_len: u32 = 0;
    // SAFETY: the input pointer/length describe a live slice, and both out-params
    // point at locals that outlive the call.
    let status = unsafe { fn_ptr(input_bytes.as_ptr(), input_len, &mut out_ptr, &mut out_len) };

    match status {
        STATUS_OK => {}
        STATUS_BUFFER_TOO_SMALL => return Err(CallError::BufferTooSmall),
        STATUS_SERIALIZATION_ERROR => {
            return Err(CallError::Serialization("FFI serialization failed".into()))
        }
        STATUS_PLUGIN_ERROR => {
            if !out_ptr.is_null() && out_len > 0 {
                // Decode into owned values first, then free on BOTH outcomes.
                // Propagating a decode failure with `?` before freeing leaked the buffer.
                let decoded: Result<PluginError, CallError> = {
                    // SAFETY: non-null pointer to `out_len` bytes written by the plugin.
                    let payload =
                        unsafe { std::slice::from_raw_parts(out_ptr, out_len as usize) };
                    wire::deserialize::<PluginError>(payload)
                        .map_err(|e| CallError::Deserialization(e.to_string()))
                };
                if let Some(free) = self.free_buffer {
                    unsafe { free(out_ptr, out_len as usize) };
                }
                return Err(match decoded {
                    Ok(plugin_err) => CallError::Plugin(plugin_err),
                    Err(err) => err,
                });
            }
            return Err(CallError::Plugin(PluginError::new(
                "UNKNOWN",
                "plugin returned error but no error data",
            )));
        }
        STATUS_PANIC => {
            let msg = if !out_ptr.is_null() && out_len > 0 {
                let slice = unsafe { std::slice::from_raw_parts(out_ptr, out_len as usize) };
                let msg = wire::deserialize::<String>(slice)
                    .unwrap_or_else(|_| "unknown panic".into());
                if let Some(free) = self.free_buffer {
                    unsafe { free(out_ptr, out_len as usize) };
                }
                msg
            } else {
                "unknown panic".into()
            };
            return Err(CallError::Panic(msg));
        }
        _ => return Err(CallError::UnknownStatus { code: status }),
    }

    if out_ptr.is_null() {
        return Err(CallError::Serialization(
            "plugin returned null output buffer".into(),
        ));
    }
    // SAFETY: non-null pointer to `out_len` bytes written by the plugin on success.
    let output_slice = unsafe { std::slice::from_raw_parts(out_ptr, out_len as usize) };
    // Decode fully (errors rendered to owned strings) before the buffer is freed.
    let result: Result<O, CallError> =
        wire::deserialize(output_slice).map_err(|e| CallError::Deserialization(e.to_string()));
    if let Some(free) = self.free_buffer {
        unsafe { free(out_ptr, out_len as usize) };
    }
    result
}
fn call_arena<O: DeserializeOwned>(
&self,
index: usize,
input_bytes: &[u8],
) -> Result<O, CallError> {
let fn_ptr = unsafe {
let fn_ptrs = self.vtable as *const ArenaFn;
*fn_ptrs.add(index)
};
let mut arena = acquire_arena(DEFAULT_ARENA_CAPACITY);
let mut out_offset: u32 = 0;
let mut out_len: u32 = 0;
let mut retried = false;
let status = loop {
let s = unsafe {
fn_ptr(
input_bytes.as_ptr(),
input_bytes.len() as u32,
arena.as_mut_ptr(),
arena.len() as u32,
&mut out_offset,
&mut out_len,
)
};
if s == STATUS_BUFFER_TOO_SMALL && !retried {
let needed = out_len as usize;
grow_arena(&mut arena, needed);
retried = true;
continue;
}
break s;
};
match status {
STATUS_OK => {
let start = out_offset as usize;
let end = start + out_len as usize;
if end > arena.len() {
release_arena(arena);
return Err(CallError::Serialization(
"plugin reported out_offset/out_len outside arena".into(),
));
}
let result = wire::deserialize(&arena[start..end])
.map_err(|e| CallError::Deserialization(e.to_string()));
release_arena(arena);
result
}
STATUS_BUFFER_TOO_SMALL => {
release_arena(arena);
Err(CallError::BufferTooSmall)
}
STATUS_SERIALIZATION_ERROR => {
release_arena(arena);
Err(CallError::Serialization("FFI serialization failed".into()))
}
STATUS_PLUGIN_ERROR => {
let start = out_offset as usize;
let end = start + out_len as usize;
let plugin_err = if out_len > 0 && end <= arena.len() {
wire::deserialize::<PluginError>(&arena[start..end]).unwrap_or_else(|_| {
PluginError::new("UNKNOWN", "plugin returned malformed error")
})
} else {
PluginError::new("UNKNOWN", "plugin returned error but no error data")
};
release_arena(arena);
Err(CallError::Plugin(plugin_err))
}
STATUS_PANIC => {
release_arena(arena);
Err(CallError::Panic(
"plugin panicked (message not transmitted via Arena strategy)".into(),
))
}
code => {
release_arena(arena);
Err(CallError::UnknownStatus { code })
}
}
}
/// Calls vtable entry `index` under the plugin-allocated buffer strategy and returns
/// a copy of the plugin's output bytes without decoding them.
///
/// The plugin reports its result through `out_ptr`/`out_len`. This function may read
/// such a buffer, and when the handle has a `free_buffer` callback, the buffer is
/// handed back to it after it has been copied or decoded, including when decoding an
/// error payload fails.
///
/// Callers must already have checked `index < self.method_count`.
///
/// # Errors
///
/// * [`CallError::Serialization`] — `input_bytes` is longer than `u32::MAX`, the
///   plugin reported a serialization failure, or it returned success with a null
///   output buffer.
/// * [`CallError::BufferTooSmall`] / [`CallError::UnknownStatus`] — the matching
///   status codes.
/// * [`CallError::Plugin`] / [`CallError::Panic`] — the plugin reported an error or
///   a panic.
/// * [`CallError::Deserialization`] — an error payload could not be decoded.
fn call_plugin_allocated_raw(
    &self,
    index: usize,
    input_bytes: &[u8],
) -> Result<Vec<u8>, CallError> {
    // The FFI signature carries the input length as a `u32`; reject inputs an `as`
    // cast would silently truncate instead of handing the plugin a partial payload.
    if input_bytes.len() > u32::MAX as usize {
        return Err(CallError::Serialization(
            "input exceeds u32::MAX bytes and cannot be passed over FFI".into(),
        ));
    }
    let input_len = input_bytes.len() as u32;

    // SAFETY: callers bounds-check `index` against `method_count`; the vtable is
    // assumed to hold `method_count` entries with the `FfiFn` ABI for this strategy.
    let fn_ptr = unsafe { *(self.vtable as *const FfiFn).add(index) };
    let mut out_ptr: *mut u8 = std::ptr::null_mut();
    let mut out_len: u32 = 0;
    // SAFETY: the input pointer/length describe a live slice, and both out-params
    // point at locals that outlive the call.
    let status = unsafe { fn_ptr(input_bytes.as_ptr(), input_len, &mut out_ptr, &mut out_len) };

    match status {
        STATUS_OK => {}
        STATUS_BUFFER_TOO_SMALL => return Err(CallError::BufferTooSmall),
        STATUS_SERIALIZATION_ERROR => {
            return Err(CallError::Serialization("FFI serialization failed".into()))
        }
        STATUS_PLUGIN_ERROR => {
            if !out_ptr.is_null() && out_len > 0 {
                // Decode into owned values first, then free on BOTH outcomes.
                // Propagating a decode failure with `?` before freeing leaked the buffer.
                let decoded: Result<PluginError, CallError> = {
                    // SAFETY: non-null pointer to `out_len` bytes written by the plugin.
                    let payload =
                        unsafe { std::slice::from_raw_parts(out_ptr, out_len as usize) };
                    wire::deserialize::<PluginError>(payload)
                        .map_err(|e| CallError::Deserialization(e.to_string()))
                };
                if let Some(free) = self.free_buffer {
                    unsafe { free(out_ptr, out_len as usize) };
                }
                return Err(match decoded {
                    Ok(plugin_err) => CallError::Plugin(plugin_err),
                    Err(err) => err,
                });
            }
            return Err(CallError::Plugin(PluginError::new(
                "UNKNOWN",
                "plugin returned error but no error data",
            )));
        }
        STATUS_PANIC => {
            let msg = if !out_ptr.is_null() && out_len > 0 {
                let slice = unsafe { std::slice::from_raw_parts(out_ptr, out_len as usize) };
                let msg = wire::deserialize::<String>(slice)
                    .unwrap_or_else(|_| "unknown panic".into());
                if let Some(free) = self.free_buffer {
                    unsafe { free(out_ptr, out_len as usize) };
                }
                msg
            } else {
                "unknown panic".into()
            };
            return Err(CallError::Panic(msg));
        }
        _ => return Err(CallError::UnknownStatus { code: status }),
    }

    if out_ptr.is_null() {
        return Err(CallError::Serialization(
            "plugin returned null output buffer".into(),
        ));
    }
    // SAFETY: non-null pointer to `out_len` bytes written by the plugin on success.
    let output_slice = unsafe { std::slice::from_raw_parts(out_ptr, out_len as usize) };
    // Copy out before the plugin's buffer is freed.
    let result = output_slice.to_vec();
    if let Some(free) = self.free_buffer {
        unsafe { free(out_ptr, out_len as usize) };
    }
    Ok(result)
}
fn call_arena_raw(&self, index: usize, input_bytes: &[u8]) -> Result<Vec<u8>, CallError> {
let fn_ptr = unsafe {
let fn_ptrs = self.vtable as *const ArenaFn;
*fn_ptrs.add(index)
};
let mut arena = acquire_arena(DEFAULT_ARENA_CAPACITY);
let mut out_offset: u32 = 0;
let mut out_len: u32 = 0;
let mut retried = false;
let status = loop {
let s = unsafe {
fn_ptr(
input_bytes.as_ptr(),
input_bytes.len() as u32,
arena.as_mut_ptr(),
arena.len() as u32,
&mut out_offset,
&mut out_len,
)
};
if s == STATUS_BUFFER_TOO_SMALL && !retried {
let needed = out_len as usize;
grow_arena(&mut arena, needed);
retried = true;
continue;
}
break s;
};
match status {
STATUS_OK => {
let start = out_offset as usize;
let end = start + out_len as usize;
if end > arena.len() {
release_arena(arena);
return Err(CallError::Serialization(
"plugin reported out_offset/out_len outside arena".into(),
));
}
let result = arena[start..end].to_vec();
release_arena(arena);
Ok(result)
}
STATUS_BUFFER_TOO_SMALL => {
release_arena(arena);
Err(CallError::BufferTooSmall)
}
STATUS_SERIALIZATION_ERROR => {
release_arena(arena);
Err(CallError::Serialization("FFI serialization failed".into()))
}
STATUS_PLUGIN_ERROR => {
let start = out_offset as usize;
let end = start + out_len as usize;
let plugin_err = if out_len > 0 && end <= arena.len() {
wire::deserialize::<PluginError>(&arena[start..end]).unwrap_or_else(|_| {
PluginError::new("UNKNOWN", "plugin returned malformed error")
})
} else {
PluginError::new("UNKNOWN", "plugin returned error but no error data")
};
release_arena(arena);
Err(CallError::Plugin(plugin_err))
}
STATUS_PANIC => {
release_arena(arena);
Err(CallError::Panic(
"plugin panicked (message not transmitted via Arena strategy)".into(),
))
}
code => {
release_arena(arena);
Err(CallError::UnknownStatus { code })
}
}
}
/// Reports whether capability bit `bit` is set in the plugin's capability mask.
///
/// The mask is a `u64`, so any `bit` of 64 or above always yields `false`.
pub fn has_capability(&self, bit: u32) -> bool {
    // The short-circuit keeps the shift amount in range: shifting a `u64` by 64 or
    // more would overflow.
    bit < 64 && (self.capabilities >> bit) & 1 == 1
}

/// Returns the plugin's identifying information (names, interface hash/version,
/// capabilities, buffer strategy and runtime kind).
pub fn info(&self) -> &PluginInfo {
    &self.info
}
/// Returns the key/value metadata attached to method `method_id`, in table order.
///
/// An empty vector is returned when `method_id` is out of range, when the
/// descriptor has no method-metadata table, or when the method has no entries.
///
/// # Panics
///
/// Panics if a key or value is not valid UTF-8.
pub fn method_metadata(&self, method_id: u32) -> Vec<(&str, &str)> {
    if method_id >= self.method_count {
        return Vec::new();
    }
    // SAFETY: NOTE(review) — assumes `self.descriptor` stays valid for the handle's
    // lifetime (kept alive by `_library`, or `'static` for in-process descriptors).
    let desc = unsafe { &*self.descriptor };
    if desc.method_metadata.is_null() {
        return Vec::new();
    }
    // SAFETY: the table is assumed to hold one entry per method, and `method_id` was
    // bounds-checked against `method_count` above.
    let entry = unsafe { &*desc.method_metadata.add(method_id as usize) };
    if entry.kvs.is_null() || entry.kv_count == 0 {
        return Vec::new();
    }
    let pairs = unsafe { std::slice::from_raw_parts(entry.kvs, entry.kv_count as usize) };
    let mut out = Vec::with_capacity(pairs.len());
    for kv in pairs {
        // NOTE(review): assumes each key/value pointer is a non-null, NUL-terminated
        // C string.
        let key = unsafe { std::ffi::CStr::from_ptr(kv.key) }
            .to_str()
            .expect("metadata key is not valid UTF-8");
        let value = unsafe { std::ffi::CStr::from_ptr(kv.value) }
            .to_str()
            .expect("metadata value is not valid UTF-8");
        out.push((key, value));
    }
    out
}
/// Returns the key/value metadata attached to the plugin's interface as a whole,
/// in table order.
///
/// An empty vector is returned when the descriptor carries no trait metadata.
///
/// # Panics
///
/// Panics if a key or value is not valid UTF-8.
pub fn trait_metadata(&self) -> Vec<(&str, &str)> {
    // SAFETY: NOTE(review) — assumes `self.descriptor` stays valid for the handle's
    // lifetime (kept alive by `_library`, or `'static` for in-process descriptors).
    let desc = unsafe { &*self.descriptor };
    let table = desc.trait_metadata;
    if table.is_null() || desc.trait_metadata_count == 0 {
        return Vec::new();
    }
    let pairs = unsafe { std::slice::from_raw_parts(table, desc.trait_metadata_count as usize) };
    let mut out = Vec::with_capacity(pairs.len());
    for kv in pairs {
        // NOTE(review): assumes each key/value pointer is a non-null, NUL-terminated
        // C string.
        let key = unsafe { std::ffi::CStr::from_ptr(kv.key) }
            .to_str()
            .expect("trait metadata key is not valid UTF-8");
        let value = unsafe { std::ffi::CStr::from_ptr(kv.value) }
            .to_str()
            .expect("trait metadata value is not valid UTF-8");
        out.push((key, value));
    }
    out
}
}