use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{
parse::Parse, parse::ParseStream, FnArg, Ident, ImplItem, ItemImpl, LitStr, Pat, Path,
ReturnType, Token, Type,
};
use crate::ir::BufferStrategyAttr;
/// Per-method data pulled out of the `impl` block and consumed by the shim generator.
struct MethodInfo<'a> {
    /// The method's name exactly as written in the impl block.
    name: &'a Ident,
    /// `true` for `async fn`; the generated shim then blocks on `FIDIUS_RUNTIME`.
    is_async: bool,
    /// `true` when the declared return type's final path segment is `Result`.
    returns_result: bool,
    /// Types of the non-receiver arguments, in declaration order.
    arg_types: Vec<&'a Type>,
    /// Bindings used for each non-receiver argument inside the generated shim
    /// (parallel to `arg_types`).
    arg_names: Vec<Ident>,
}
/// Reports whether `ty` is a path type whose last segment is named `Result`
/// (for example `Result<T, E>`, `std::result::Result<T, E>` or `io::Result<T>`).
///
/// The check is purely syntactic: aliases with other names, and parenthesized or
/// grouped types, are not recognised.
fn is_result_type(ty: &Type) -> bool {
    match ty {
        Type::Path(type_path) => matches!(
            type_path.path.segments.last(),
            Some(last_segment) if last_segment.ident == "Result"
        ),
        _ => false,
    }
}
/// Parsed arguments of the `#[plugin_impl(...)]` attribute.
pub struct PluginImplAttrs {
    /// The plugin trait being implemented (first, positional argument).
    pub trait_name: Ident,
    /// Path of the runtime crate referenced by generated code
    /// (`crate = "..."`, defaulting to `fidius`).
    pub crate_path: Path,
    /// How output buffers are returned to the host
    /// (`buffer = ...`, defaulting to `PluginAllocated`).
    pub buffer_strategy: BufferStrategyAttr,
}
impl Parse for PluginImplAttrs {
fn parse(input: ParseStream) -> syn::Result<Self> {
let trait_name: Ident = input.parse()?;
let mut crate_path = None;
let mut buffer_strategy = None;
while !input.is_empty() {
let _comma: Token![,] = input.parse()?;
if input.peek(Token![crate]) {
let _kw: Token![crate] = input.parse()?;
let _eq: Token![=] = input.parse()?;
let lit: LitStr = input.parse()?;
let path: Path = lit.parse()?;
crate_path = Some(path);
} else {
let key: Ident = input.parse()?;
let _eq: Token![=] = input.parse()?;
match key.to_string().as_str() {
"buffer" => {
let value: Ident = input.parse()?;
buffer_strategy = Some(match value.to_string().as_str() {
"PluginAllocated" => BufferStrategyAttr::PluginAllocated,
"Arena" => BufferStrategyAttr::Arena,
_ => {
return Err(syn::Error::new(
value.span(),
"expected PluginAllocated or Arena",
))
}
});
}
other => {
return Err(syn::Error::new(
key.span(),
format!("unknown plugin_impl attribute `{other}`"),
));
}
}
}
}
let crate_path = crate_path.unwrap_or_else(|| syn::parse_str::<Path>("fidius").unwrap());
let buffer_strategy = buffer_strategy.unwrap_or(BufferStrategyAttr::PluginAllocated);
Ok(PluginImplAttrs {
trait_name,
crate_path,
buffer_strategy,
})
}
}
pub fn generate_plugin_impl(attrs: &PluginImplAttrs, item: &ItemImpl) -> syn::Result<TokenStream> {
let trait_name = &attrs.trait_name;
let impl_type = &item.self_ty;
let impl_type_str = quote!(#impl_type).to_string().replace(' ', "");
let impl_ident = format_ident!("{}", impl_type_str);
let impl_methods: Vec<MethodInfo> = item
.items
.iter()
.filter_map(|item| {
if let ImplItem::Fn(method) = item {
let returns_result = match &method.sig.output {
ReturnType::Type(_, ty) => is_result_type(ty),
ReturnType::Default => false,
};
let mut arg_types = Vec::new();
let mut arg_names = Vec::new();
for arg in &method.sig.inputs {
if let FnArg::Typed(pat_type) = arg {
arg_types.push(pat_type.ty.as_ref());
if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
arg_names.push(pat_ident.ident.clone());
} else {
arg_names.push(format_ident!("_arg"));
}
}
}
Some(MethodInfo {
name: &method.sig.ident,
is_async: method.sig.asyncness.is_some(),
returns_result,
arg_types,
arg_names,
})
} else {
None
}
})
.collect();
let method_names: Vec<&Ident> = impl_methods.iter().map(|m| m.name).collect();
let _has_async = impl_methods.iter().any(|m| m.is_async);
let crate_path = &attrs.crate_path;
let buffer_strategy = attrs.buffer_strategy;
let shims = generate_shims(&impl_ident, &impl_methods, crate_path, buffer_strategy);
let instance_name = format_ident!("__FIDIUS_INSTANCE_{}", impl_ident);
let instance = quote! {
static #instance_name: #impl_type = #impl_type;
};
let vtable = generate_vtable_static(trait_name, &impl_ident, &method_names);
let free_fn_name = format_ident!("__fidius_free_buffer_{}", impl_ident);
let free_buffer = match buffer_strategy {
BufferStrategyAttr::PluginAllocated => quote! {
unsafe extern "C" fn #free_fn_name(ptr: *mut u8, len: usize) {
if !ptr.is_null() && len > 0 {
unsafe {
let slice = std::slice::from_raw_parts_mut(ptr, len);
drop(Box::from_raw(slice as *mut [u8]));
}
}
}
},
BufferStrategyAttr::Arena => quote! {},
};
let descriptor = generate_descriptor(
trait_name,
&impl_ident,
&method_names,
crate_path,
buffer_strategy,
);
let registration = generate_inventory_registration(&impl_ident, crate_path);
Ok(quote! {
#item
#instance
#shims
#free_buffer
#vtable
#descriptor
#registration
})
}
/// Generates one `extern "C"` shim per impl method.
///
/// Each shim:
/// - decodes its arguments from the host's input buffer with `wire::deserialize`,
/// - calls the method on the static plugin instance, blocking on
///   `async_runtime::FIDIUS_RUNTIME` for `async fn`s,
/// - serializes the return value. For `Result`-returning methods, `Ok` maps to
///   `STATUS_OK` and `Err` to `STATUS_PLUGIN_ERROR`.
/// - hands the bytes back according to `buffer_strategy`.
///
/// Panics inside the method are caught and reported as `STATUS_PANIC`.
fn generate_shims(
    impl_ident: &Ident,
    methods: &[MethodInfo],
    crate_path: &Path,
    buffer_strategy: BufferStrategyAttr,
) -> TokenStream {
    let instance_name = format_ident!("__FIDIUS_INSTANCE_{}", impl_ident);
    let shim_fns: Vec<TokenStream> = methods
        .iter()
        .map(|method| {
            let shim_name = format_ident!("__fidius_shim_{}_{}", impl_ident, method.name);
            let invoke = shim_invoke_body(&instance_name, method, crate_path);
            match buffer_strategy {
                BufferStrategyAttr::Arena => arena_shim(&shim_name, &invoke, crate_path),
                BufferStrategyAttr::PluginAllocated => {
                    plugin_allocated_shim(&shim_name, &invoke, crate_path)
                }
            }
        })
        .collect();
    quote! { #(#shim_fns)* }
}

/// Builds the strategy-independent statements of a shim closure.
///
/// The statements decode the input, call the method, and serialize its result. They
/// leave `output_bytes` (the bytes returned by `wire::serialize`) and `status` in
/// scope. They may `return` a status early from the enclosing closure. They expect
/// `in_ptr: *const u8` and `in_len: u32` to be in scope.
fn shim_invoke_body(instance_name: &Ident, method: &MethodInfo, crate_path: &Path) -> TokenStream {
    let method_name = method.name;
    let arg_types = &method.arg_types;
    let arg_names = &method.arg_names;
    let method_call = if method.is_async {
        quote! {
            #crate_path::async_runtime::FIDIUS_RUNTIME.block_on(
                #instance_name.#method_name(#(#arg_names),*)
            )
        }
    } else {
        quote! { #instance_name.#method_name(#(#arg_names),*) }
    };
    let output_handling = if method.returns_result {
        quote! {
            match output {
                Ok(val) => {
                    match #crate_path::wire::serialize(&val) {
                        Ok(v) => (v, #crate_path::status::STATUS_OK),
                        Err(_) => return #crate_path::status::STATUS_SERIALIZATION_ERROR,
                    }
                }
                Err(err) => {
                    match #crate_path::wire::serialize(&err) {
                        Ok(v) => (v, #crate_path::status::STATUS_PLUGIN_ERROR),
                        Err(_) => return #crate_path::status::STATUS_SERIALIZATION_ERROR,
                    }
                }
            }
        }
    } else {
        quote! {
            match #crate_path::wire::serialize(&output) {
                Ok(v) => (v, #crate_path::status::STATUS_OK),
                Err(_) => return #crate_path::status::STATUS_SERIALIZATION_ERROR,
            }
        }
    };
    quote! {
        // `slice::from_raw_parts` requires a non-null pointer even for an empty
        // slice, so `in_ptr` is not touched when there is no input (hosts may pass null).
        let in_slice: &[u8] = if in_len == 0 {
            &[]
        } else {
            unsafe { ::std::slice::from_raw_parts(in_ptr, in_len as usize) }
        };
        let (#(#arg_names,)*) = match #crate_path::wire::deserialize::<(#(#arg_types,)*)>(in_slice) {
            Ok(v) => v,
            Err(_) => return #crate_path::status::STATUS_SERIALIZATION_ERROR,
        };
        let output = #method_call;
        let (output_bytes, status) = #output_handling;
        // Lengths cross the FFI boundary as `u32`. A larger output would be silently
        // truncated, and for plugin-allocated buffers the truncated length would later
        // mis-size the deallocation in `free_buffer`.
        if output_bytes.len() > u32::MAX as usize {
            return #crate_path::status::STATUS_SERIALIZATION_ERROR;
        }
    }
}

/// Emits a shim that copies its output into a host-provided arena, at offset 0.
///
/// If the output does not fit, only `out_len` is written (with the required size)
/// and `STATUS_BUFFER_TOO_SMALL` is returned. On panic, both out-parameters are
/// zeroed.
fn arena_shim(shim_name: &Ident, invoke: &TokenStream, crate_path: &Path) -> TokenStream {
    quote! {
        unsafe extern "C" fn #shim_name(
            in_ptr: *const u8,
            in_len: u32,
            arena_ptr: *mut u8,
            arena_cap: u32,
            out_offset: *mut u32,
            out_len: *mut u32,
        ) -> i32 {
            let result = ::std::panic::catch_unwind(::std::panic::AssertUnwindSafe(|| {
                #invoke
                if output_bytes.len() > arena_cap as usize {
                    unsafe {
                        *out_len = output_bytes.len() as u32;
                    }
                    return #crate_path::status::STATUS_BUFFER_TOO_SMALL;
                }
                // Copy only when there is something to copy, so an empty output never
                // dereferences a (possibly null) zero-capacity arena pointer.
                if !output_bytes.is_empty() {
                    unsafe {
                        ::std::ptr::copy_nonoverlapping(
                            output_bytes.as_ptr(),
                            arena_ptr,
                            output_bytes.len(),
                        );
                    }
                }
                unsafe {
                    *out_offset = 0;
                    *out_len = output_bytes.len() as u32;
                }
                status
            }));
            match result {
                Ok(status) => status,
                Err(_panic_payload) => {
                    unsafe {
                        *out_offset = 0;
                        *out_len = 0;
                    }
                    #crate_path::status::STATUS_PANIC
                }
            }
        }
    }
}

/// Emits a shim that returns its output as a leaked `Box<[u8]>` via `out_ptr` and
/// `out_len`.
///
/// NOTE(review): presumably the host releases it through the descriptor's
/// `free_buffer`; confirm against the host side. On panic, the serialized panic
/// message is returned the same way. If it cannot be serialized, null/0 is written
/// so the out-parameters are never left unset.
fn plugin_allocated_shim(shim_name: &Ident, invoke: &TokenStream, crate_path: &Path) -> TokenStream {
    quote! {
        unsafe extern "C" fn #shim_name(
            in_ptr: *const u8,
            in_len: u32,
            out_ptr: *mut *mut u8,
            out_len: *mut u32,
        ) -> i32 {
            let result = ::std::panic::catch_unwind(::std::panic::AssertUnwindSafe(|| {
                #invoke
                let boxed: Box<[u8]> = output_bytes.into_boxed_slice();
                let len = boxed.len();
                let ptr = Box::into_raw(boxed) as *mut u8;
                unsafe {
                    *out_ptr = ptr;
                    *out_len = len as u32;
                }
                status
            }));
            match result {
                Ok(status) => status,
                Err(panic_payload) => {
                    let msg = panic_payload
                        .downcast_ref::<&str>()
                        .map(|s| s.to_string())
                        .or_else(|| panic_payload.downcast_ref::<String>().cloned())
                        .unwrap_or_else(|| "unknown panic".to_string());
                    let (ptr, len) = match #crate_path::wire::serialize(&msg) {
                        Ok(msg_bytes) if msg_bytes.len() <= u32::MAX as usize => {
                            let boxed: Box<[u8]> = msg_bytes.into_boxed_slice();
                            let len = boxed.len();
                            (Box::into_raw(boxed) as *mut u8, len)
                        }
                        _ => (::std::ptr::null_mut(), 0),
                    };
                    unsafe {
                        *out_ptr = ptr;
                        *out_len = len as u32;
                    }
                    #crate_path::status::STATUS_PANIC
                }
            }
        }
    }
}
/// Emits `static __FIDIUS_VTABLE_<Impl>`.
///
/// The static is built by the trait companion module's `new_<lowercased trait>_vtable`
/// constructor. It receives one shim function per method, in the same order as
/// `methods`.
fn generate_vtable_static(
    trait_name: &Ident,
    impl_ident: &Ident,
    methods: &[&Ident],
) -> TokenStream {
    let lowered_trait = trait_name.to_string().to_lowercase();
    let companion_mod = format_ident!("__fidius_{}", trait_name);
    let vtable_ty = format_ident!("{}_VTable", trait_name);
    let vtable_static = format_ident!("__FIDIUS_VTABLE_{}", impl_ident);
    let ctor = format_ident!("new_{}_vtable", lowered_trait);
    let shims = methods
        .iter()
        .map(|method| format_ident!("__fidius_shim_{}_{}", impl_ident, method))
        .collect::<Vec<Ident>>();
    quote! {
        static #vtable_static: #companion_mod::#vtable_ty = #companion_mod::#ctor(#(#shims),*);
    }
}
/// Emits the plugin's NUL-terminated name constant and its `PluginDescriptor` static.
///
/// The descriptor is produced by the trait companion module's
/// `__fidius_build_<lowercased trait>_descriptor` function. It receives, in order:
/// - the C-string name of the implementing type,
/// - a pointer to `__FIDIUS_VTABLE_<Impl>`, cast to whatever pointer type the
///   builder expects,
/// - the capability bitmask `CAPS`,
/// - the buffer-free function (`Some` only for `PluginAllocated`),
/// - the number of methods in the impl block.
///
/// `CAPS` is computed at compile time. Bit `i` is set when entry `i` of the
/// companion's `<Trait>_OPTIONAL_METHODS` list has the same name as one of the
/// implemented methods. The comparison is spelled out with `while` loops because it
/// runs in a `const` initializer, where iterators and `str` equality are unavailable.
/// NOTE(review): more than 64 optional methods would overflow `1u64 << opt_idx` and
/// fail const evaluation. Confirm the companion macro bounds that list.
fn generate_descriptor(
    trait_name: &Ident,
    impl_ident: &Ident,
    methods: &[&Ident],
    crate_path: &Path,
    buffer_strategy: BufferStrategyAttr,
) -> TokenStream {
    let companion = format_ident!("__fidius_{}", trait_name);
    let vtable_name = format_ident!("__FIDIUS_VTABLE_{}", impl_ident);
    let descriptor_name = format_ident!("__FIDIUS_DESCRIPTOR_{}", impl_ident);
    let free_fn_name = format_ident!("__fidius_free_buffer_{}", impl_ident);
    let builder_fn = format_ident!(
        "__fidius_build_{}_descriptor",
        trait_name.to_string().to_lowercase()
    );
    let plugin_name_const = format_ident!("__FIDIUS_PLUGIN_NAME_{}", impl_ident);
    let impl_name_str = impl_ident.to_string();
    let optional_methods_ident = format_ident!("{}_OPTIONAL_METHODS", trait_name);
    // Method names as string literals, for the const-time capability match below.
    let method_strs: Vec<String> = methods.iter().map(|m| m.to_string()).collect();
    let method_count = methods.len() as u32;
    // Only plugin-allocated buffers need a plugin-side free function.
    let free_buffer_expr = match buffer_strategy {
        BufferStrategyAttr::PluginAllocated => quote! { Some(#free_fn_name) },
        BufferStrategyAttr::Arena => quote! { None },
    };
    quote! {
        // Impl type name with an explicit trailing NUL; the literal contains no
        // interior NUL because it is an identifier.
        const #plugin_name_const: &std::ffi::CStr = unsafe {
            std::ffi::CStr::from_bytes_with_nul_unchecked(concat!(#impl_name_str, "\0").as_bytes())
        };
        static #descriptor_name: #crate_path::descriptor::PluginDescriptor = unsafe {
            const CAPS: u64 = {
                let optional = #companion::#optional_methods_ident;
                let impl_methods: &[&str] = &[#(#method_strs),*];
                let mut caps: u64 = 0;
                let mut opt_idx = 0;
                // For each optional method, look for an implemented method with the
                // same name (length check first, then byte-by-byte comparison).
                while opt_idx < optional.len() {
                    let opt_name = optional[opt_idx];
                    let mut impl_idx = 0;
                    while impl_idx < impl_methods.len() {
                        let impl_name = impl_methods[impl_idx];
                        if opt_name.len() == impl_name.len() {
                            let ob = opt_name.as_bytes();
                            let ib = impl_name.as_bytes();
                            let mut j = 0;
                            let mut eq = true;
                            while j < ob.len() {
                                if ob[j] != ib[j] { eq = false; }
                                j += 1;
                            }
                            if eq {
                                caps |= 1u64 << opt_idx;
                            }
                        }
                        impl_idx += 1;
                    }
                    opt_idx += 1;
                }
                caps
            };
            #companion::#builder_fn(
                #plugin_name_const.as_ptr(),
                &#vtable_name as *const _ as *const _,
                CAPS,
                #free_buffer_expr,
                #method_count,
            )
        };
    }
}
/// Submits a `registry::DescriptorEntry` pointing at `__FIDIUS_DESCRIPTOR_<Impl>`
/// through the runtime crate's re-exported `inventory::submit!`.
fn generate_inventory_registration(impl_ident: &Ident, crate_path: &Path) -> TokenStream {
    let descriptor_static = format_ident!("__FIDIUS_DESCRIPTOR_{}", impl_ident);
    let entry = quote! {
        #crate_path::registry::DescriptorEntry {
            descriptor: &#descriptor_static,
        }
    };
    quote! {
        #crate_path::inventory::submit! { #entry }
    }
}