use std::ffi::c_void;
use std::sync::Arc;
use libloading::Library;
use serde::de::DeserializeOwned;
use serde::Serialize;
use fidius_core::status::*;
use fidius_core::wire;
use fidius_core::PluginError;
use crate::error::CallError;
use crate::types::PluginInfo;
/// ABI of one plugin method slot in the vtable.
///
/// The host passes the serialized input as a pointer/length pair; the plugin may write an
/// output buffer pointer and its length through the two out-parameters, and returns a
/// status code (one of the `STATUS_*` constants, or an unknown value).
type FfiFn = unsafe extern "C" fn(
    input: *const u8,
    input_len: u32,
    out_ptr: *mut *mut u8,
    out_len: *mut u32,
) -> i32;
/// A loaded plugin instance whose methods can be invoked through [`PluginHandle::call_method`].
///
/// NOTE(review): fields drop in declaration order, so `_library` is released before the
/// remaining fields. None of them visibly depend on the library staying mapped, but if one
/// ever does, `_library` should be moved to the end of the struct.
pub struct PluginHandle {
    /// Shared library the plugin came from. Never read here; holding the `Arc` keeps the
    /// `Library` from being dropped while this handle exists.
    _library: Arc<Library>,
    /// Start of the plugin's method table; `call_method` reads it as an array of `FfiFn`.
    vtable: *const c_void,
    /// Deallocator for output buffers the plugin hands back. When `None`, the host never
    /// frees those buffers.
    free_buffer: Option<unsafe extern "C" fn(*mut u8, usize)>,
    /// Capability bitmask queried by `has_capability`.
    capabilities: u64,
    /// Exclusive upper bound on the method indices `call_method` accepts.
    method_count: u32,
    /// Plugin metadata exposed through `info()`.
    info: PluginInfo,
}
// SAFETY: NOTE(review): the raw `vtable` pointer opts this type out of `Send`/`Sync` by
// default. These impls assert that the plugin's vtable functions and `free_buffer` may be
// called from any thread, including concurrently through `&self`. That contract is not
// checked here; confirm it against the plugin ABI.
unsafe impl Send for PluginHandle {}
unsafe impl Sync for PluginHandle {}
impl PluginHandle {
    /// Builds a handle from already-resolved parts.
    ///
    /// `capabilities` is stored as given; it is not cross-checked against `info`.
    // NOTE(review): presumably only some crate-internal callers (e.g. tests) use this
    // constructor, hence the lint allowance. Confirm before removing it.
    #[allow(dead_code)]
    pub(crate) fn new(
        library: Arc<Library>,
        vtable: *const c_void,
        free_buffer: Option<unsafe extern "C" fn(*mut u8, usize)>,
        capabilities: u64,
        method_count: u32,
        info: PluginInfo,
    ) -> Self {
        Self {
            _library: library,
            vtable,
            free_buffer,
            capabilities,
            method_count,
            info,
        }
    }

    /// Wraps a plugin produced by the loader, taking ownership of its library handle.
    ///
    /// The capability bitmask is copied out of `plugin.info.capabilities`.
    pub fn from_loaded(plugin: crate::loader::LoadedPlugin) -> Self {
        Self {
            _library: plugin.library,
            vtable: plugin.vtable,
            free_buffer: plugin.free_buffer,
            capabilities: plugin.info.capabilities,
            method_count: plugin.method_count,
            info: plugin.info,
        }
    }

    /// Invokes the plugin method at vtable slot `index`.
    ///
    /// `input` is encoded with `wire::serialize` and passed to the plugin. The bytes the
    /// plugin writes back are decoded with `wire::deserialize`. Any output buffer that is
    /// read is released through `free_buffer` (when present) afterwards, including when
    /// decoding fails.
    ///
    /// # Errors
    ///
    /// - [`CallError::NotImplemented`] if `index` is not below the method count.
    /// - [`CallError::Serialization`] if any of the following happens:
    ///   - `input` cannot be encoded;
    ///   - the encoded input exceeds `u32::MAX` bytes;
    ///   - the plugin reports a serialization failure;
    ///   - the plugin reports success without an output buffer.
    /// - [`CallError::BufferTooSmall`] if the plugin reports `STATUS_BUFFER_TOO_SMALL`.
    /// - [`CallError::Plugin`] carrying the plugin's own error.
    /// - [`CallError::Deserialization`] if that error, or a successful result, cannot be
    ///   decoded.
    /// - [`CallError::Panic`] if the plugin reports a panic.
    /// - [`CallError::UnknownStatus`] for any other status code.
    pub fn call_method<I: Serialize, O: DeserializeOwned>(
        &self,
        index: usize,
        input: &I,
    ) -> Result<O, CallError> {
        if index >= self.method_count as usize {
            // Clamp rather than truncate so an absurd index can't alias a small bit number.
            let bit = index.min(u32::MAX as usize) as u32;
            return Err(CallError::NotImplemented { bit });
        }

        let input_bytes =
            wire::serialize(input).map_err(|e| CallError::Serialization(e.to_string()))?;
        // The ABI carries the length as u32; a plain `as` cast would silently truncate and
        // hand the plugin a shorter buffer than we serialized.
        if input_bytes.len() > u32::MAX as usize {
            return Err(CallError::Serialization(format!(
                "serialized input is {} bytes, exceeding the FFI limit of {} bytes",
                input_bytes.len(),
                u32::MAX
            )));
        }
        let input_len = input_bytes.len() as u32;

        // SAFETY: `index < method_count` was checked above. NOTE(review): this assumes the
        // vtable holds at least `method_count` consecutive `FfiFn` entries. Confirm
        // against the loader.
        let fn_ptr = unsafe { *(self.vtable as *const FfiFn).add(index) };

        let mut out_ptr: *mut u8 = std::ptr::null_mut();
        let mut out_len: u32 = 0;
        // SAFETY: the input pointer/length describe `input_bytes`, which outlives the call,
        // and both out-parameters point at live locals.
        let status =
            unsafe { fn_ptr(input_bytes.as_ptr(), input_len, &mut out_ptr, &mut out_len) };

        match status {
            STATUS_OK => {
                if out_ptr.is_null() {
                    return Err(CallError::Serialization(
                        "plugin returned null output buffer".into(),
                    ));
                }
                // SAFETY: non-null buffer of `out_len` bytes written by the plugin call above.
                unsafe {
                    self.read_and_free_output(out_ptr, out_len, |bytes| {
                        wire::deserialize::<O>(bytes)
                            .map_err(|e| CallError::Deserialization(e.to_string()))
                    })
                }
            }
            STATUS_BUFFER_TOO_SMALL => Err(CallError::BufferTooSmall),
            STATUS_SERIALIZATION_ERROR => {
                Err(CallError::Serialization("FFI serialization failed".into()))
            }
            // SAFETY: `out_ptr`/`out_len` are exactly what the plugin wrote above.
            STATUS_PLUGIN_ERROR => Err(unsafe { self.plugin_error_from_output(out_ptr, out_len) }),
            // SAFETY: `out_ptr`/`out_len` are exactly what the plugin wrote above.
            STATUS_PANIC => Err(CallError::Panic(unsafe {
                self.panic_message_from_output(out_ptr, out_len)
            })),
            _ => Err(CallError::UnknownStatus { code: status }),
        }
    }

    /// Turns the payload of a `STATUS_PLUGIN_ERROR` result into a `CallError`.
    ///
    /// Returns the possible errors as follows:
    /// - a generic `"UNKNOWN"` plugin error when no payload was written;
    /// - [`CallError::Deserialization`] when the payload cannot be decoded;
    /// - otherwise the decoded [`CallError::Plugin`].
    ///
    /// A non-empty payload is freed in every case.
    ///
    /// # Safety
    ///
    /// If `out_ptr` is non-null and `out_len > 0`, `out_ptr` must point to `out_len`
    /// readable bytes that may be passed to `free_buffer`. It must not be used after this
    /// call.
    unsafe fn plugin_error_from_output(&self, out_ptr: *mut u8, out_len: u32) -> CallError {
        if out_ptr.is_null() || out_len == 0 {
            return CallError::Plugin(PluginError::new(
                "UNKNOWN",
                "plugin returned error but no error data",
            ));
        }
        // SAFETY: upheld by this function's caller (see `# Safety`).
        let decoded = unsafe {
            self.read_and_free_output(out_ptr, out_len, |bytes| {
                wire::deserialize::<PluginError>(bytes).map_err(|e| e.to_string())
            })
        };
        match decoded {
            Ok(plugin_err) => CallError::Plugin(plugin_err),
            Err(msg) => CallError::Deserialization(msg),
        }
    }

    /// Extracts the panic message attached to a `STATUS_PANIC` result.
    ///
    /// Falls back to `"unknown panic"` when there is no payload or it cannot be decoded.
    /// A non-empty payload is freed.
    ///
    /// # Safety
    ///
    /// Same contract as `plugin_error_from_output`.
    unsafe fn panic_message_from_output(&self, out_ptr: *mut u8, out_len: u32) -> String {
        if out_ptr.is_null() || out_len == 0 {
            return "unknown panic".into();
        }
        // SAFETY: upheld by this function's caller (see `# Safety`).
        unsafe {
            self.read_and_free_output(out_ptr, out_len, |bytes| {
                wire::deserialize::<String>(bytes).unwrap_or_else(|_| "unknown panic".into())
            })
        }
    }

    /// Runs `decode` over a plugin-owned output buffer, then releases the buffer.
    ///
    /// The buffer is handed to `free_buffer` after `decode` returns, regardless of what
    /// `decode` produced. When `free_buffer` is `None`, the buffer is left alone. The
    /// borrowed slice cannot escape through `R`, because `R` is independent of the slice's
    /// lifetime.
    ///
    /// # Safety
    ///
    /// `ptr` must be non-null and point to `len` initialized bytes that may be passed to
    /// `free_buffer`. It must not be used again after this call.
    unsafe fn read_and_free_output<R>(
        &self,
        ptr: *mut u8,
        len: u32,
        decode: impl FnOnce(&[u8]) -> R,
    ) -> R {
        let len = len as usize;
        // SAFETY: guaranteed by the caller (see `# Safety`).
        let result = decode(unsafe { std::slice::from_raw_parts(ptr, len) });
        if let Some(free) = self.free_buffer {
            // SAFETY: `ptr`/`len` describe the plugin-allocated buffer, and the slice
            // borrowed above is no longer alive.
            unsafe { free(ptr, len) };
        }
        result
    }

    /// Reports whether capability bit `bit` is set. Bits 64 and above are never set.
    pub fn has_capability(&self, bit: u32) -> bool {
        // Short-circuit keeps the shift in range (shifting a u64 by >= 64 overflows).
        bit < 64 && self.capabilities & (1u64 << bit) != 0
    }

    /// Metadata describing this plugin.
    pub fn info(&self) -> &PluginInfo {
        &self.info
    }
}