use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{ItemTrait, TraitItem};
use crate::ir::{BufferStrategyAttr, InterfaceIR};
/// Returns a copy of `item` in which every trait method has had its
/// `optional` attributes removed.
///
/// Only `TraitItem::Fn` entries are touched; all other trait items and all
/// other attributes are carried over from the clone unchanged.
fn strip_optional_attrs(item: &ItemTrait) -> ItemTrait {
    let mut stripped = item.clone();
    stripped
        .items
        .iter_mut()
        .filter_map(|entry| match entry {
            TraitItem::Fn(func) => Some(func),
            _ => None,
        })
        .for_each(|func| {
            // Keep everything except attributes whose path is exactly `optional`.
            func.attrs.retain(|a| !a.path().is_ident("optional"));
        });
    stripped
}
/// Expands an interface description into the re-emitted trait plus its
/// companion module.
///
/// The output consists of the original trait with `optional` method
/// attributes stripped, followed by a `pub mod __fidius_<TraitName>` that
/// contains, in order: the vtable, the constants, the method indices and the
/// descriptor builder.
///
/// # Errors
///
/// Returns a `syn::Error` spanned on the trait identifier when the buffer
/// strategy is `CallerAllocated` or `Arena`; only `PluginAllocated` is
/// accepted here.
pub fn generate_interface(ir: &InterfaceIR) -> syn::Result<TokenStream> {
    // Exhaustive on purpose: a new strategy variant must be classified here.
    let unsupported = match ir.attrs.buffer_strategy {
        BufferStrategyAttr::PluginAllocated => None,
        BufferStrategyAttr::CallerAllocated => {
            Some("CallerAllocated buffer strategy is not yet supported")
        }
        BufferStrategyAttr::Arena => Some("Arena buffer strategy is not yet supported"),
    };
    if let Some(message) = unsupported {
        return Err(syn::Error::new_spanned(
            &ir.original_trait.ident,
            message,
        ));
    }

    let companion_mod = format_ident!("__fidius_{}", ir.trait_name);
    let trait_tokens = strip_optional_attrs(&ir.original_trait);
    let vtable_tokens = generate_vtable(ir);
    let constant_tokens = generate_constants(ir);
    let descriptor_tokens = generate_descriptor_builder(ir);
    let index_tokens = generate_method_indices(ir);

    Ok(quote! {
        #trait_tokens
        #[allow(non_snake_case, non_upper_case_globals, dead_code)]
        pub mod #companion_mod {
            use super::*;
            #vtable_tokens
            #constant_tokens
            #index_tokens
            #descriptor_tokens
        }
    })
}
/// Generates the `#[repr(C)]` vtable struct for the interface and a `const fn`
/// constructor for it.
///
/// The struct is named `<TraitName>_VTable` and has one field per method, in
/// the order of `ir.methods`. Methods with `optional_since` set get an
/// `Option<fn>` field; all others get a bare function pointer.
///
/// The constructor is named `new_<traitname lowercased>_vtable` and takes one
/// plain function pointer per method (also for optional methods), wrapping
/// the optional ones in `Some`.
fn generate_vtable(ir: &InterfaceIR) -> TokenStream {
    let vtable_name = format_ident!("{}_VTable", ir.trait_name);
    let constructor_name =
        format_ident!("new_{}_vtable", ir.trait_name.to_string().to_lowercase());

    // Single definition of the function-pointer type shared by every vtable
    // slot and every constructor parameter, so the two cannot drift apart.
    let fn_type = quote! {
        unsafe extern "C" fn(*const u8, u32, *mut *mut u8, *mut u32) -> i32
    };

    let mut fields: Vec<TokenStream> = Vec::new();
    let mut params: Vec<TokenStream> = Vec::new();
    let mut field_assigns: Vec<TokenStream> = Vec::new();

    for m in ir.methods.iter() {
        let name = &m.name;

        // The constructor always receives a concrete function pointer.
        params.push(quote! { #name: #fn_type });

        if m.optional_since.is_some() {
            fields.push(quote! { pub #name: Option<#fn_type> });
            field_assigns.push(quote! { #name: Some(#name) });
        } else {
            fields.push(quote! { pub #name: #fn_type });
            field_assigns.push(quote! { #name: #name });
        }
    }

    quote! {
        #[repr(C)]
        pub struct #vtable_name {
            #(#fields,)*
        }
        pub const fn #constructor_name(#(#params),*) -> #vtable_name {
            #vtable_name {
                #(#field_assigns,)*
            }
        }
    }
}
fn generate_constants(ir: &InterfaceIR) -> TokenStream {
let trait_name = &ir.trait_name;
let required_sigs: Vec<&str> = ir
.methods
.iter()
.filter(|m| m.is_required())
.map(|m| m.signature_string.as_str())
.collect();
let hash_value = fidius_core::hash::interface_hash(&required_sigs);
let hash_name = format_ident!("{}_INTERFACE_HASH", trait_name);
let version_name = format_ident!("{}_INTERFACE_VERSION", trait_name);
let strategy_name = format_ident!("{}_BUFFER_STRATEGY", trait_name);
let version_val = ir.attrs.version;
let strategy_val = ir.attrs.buffer_strategy as u8;
let cap_constants: Vec<TokenStream> = ir
.methods
.iter()
.filter(|m| m.optional_since.is_some())
.enumerate()
.map(|(bit, m)| {
let const_name =
format_ident!("{}_CAP_{}", trait_name, m.name.to_string().to_uppercase());
let bit_val = 1u64 << bit;
quote! { pub const #const_name: u64 = #bit_val; }
})
.collect();
let optional_names_ident = format_ident!("{}_OPTIONAL_METHODS", trait_name);
let optional_names: Vec<String> = ir
.methods
.iter()
.filter(|m| m.optional_since.is_some())
.map(|m| m.name.to_string())
.collect();
quote! {
pub const #hash_name: u64 = #hash_value;
pub const #version_name: u32 = #version_val;
pub const #strategy_name: u8 = #strategy_val;
#(#cap_constants)*
pub const #optional_names_ident: &[&str] = &[#(#optional_names),*];
}
}
/// Generates the `const unsafe fn` that assembles a `PluginDescriptor` for
/// this interface, plus the private NUL-terminated interface-name constant it
/// points at.
///
/// The generated function is named
/// `__fidius_build_<traitname lowercased>_descriptor`. It takes the plugin
/// name pointer, the vtable pointer, the capability mask, the optional
/// `free_buffer` callback and the method count from the caller, and fills the
/// remaining descriptor fields from the interface constants
/// (`<Trait>_INTERFACE_HASH`, `<Trait>_INTERFACE_VERSION`,
/// `<Trait>_BUFFER_STRATEGY`) and from `ABI_VERSION` / `WIRE_FORMAT` under
/// `ir.attrs.crate_path`.
///
/// Standard-library items in the generated code are referenced through
/// absolute `::std::...` paths. The code is emitted into a module that does
/// `use super::*`, so a user item named `std` would otherwise shadow them.
fn generate_descriptor_builder(ir: &InterfaceIR) -> TokenStream {
    let trait_name = &ir.trait_name;
    let crate_path = &ir.attrs.crate_path;

    let vtable_name = format_ident!("{}_VTable", trait_name);
    let fn_name = format_ident!(
        "__fidius_build_{}_descriptor",
        trait_name.to_string().to_lowercase()
    );

    // These must match the identifiers emitted for the interface constants.
    let hash_name = format_ident!("{}_INTERFACE_HASH", trait_name);
    let version_name = format_ident!("{}_INTERFACE_VERSION", trait_name);
    let strategy_name = format_ident!("{}_BUFFER_STRATEGY", trait_name);

    let interface_name_str = trait_name.to_string();
    let interface_name_cstr_ident = format_ident!("__FIDIUS_INTERFACE_NAME_{}", trait_name);

    quote! {
        const #interface_name_cstr_ident: &::std::ffi::CStr = {
            // SAFETY (generated code): `concat!(<name>, "\0")` always ends in
            // a NUL byte. The name is presumably a Rust identifier and so has
            // no interior NUL -- NOTE(review): confirm `trait_name` is an ident.
            unsafe { ::std::ffi::CStr::from_bytes_with_nul_unchecked(concat!(#interface_name_str, "\0").as_bytes()) }
        };
        pub const unsafe fn #fn_name(
            plugin_name: *const ::std::ffi::c_char,
            vtable: *const #vtable_name,
            capabilities: u64,
            free_buffer: ::std::option::Option<unsafe extern "C" fn(*mut u8, usize)>,
            method_count: u32,
        ) -> #crate_path::descriptor::PluginDescriptor {
            #crate_path::descriptor::PluginDescriptor {
                abi_version: #crate_path::descriptor::ABI_VERSION,
                interface_name: #interface_name_cstr_ident.as_ptr(),
                interface_hash: #hash_name,
                interface_version: #version_name,
                capabilities,
                wire_format: #crate_path::wire::WIRE_FORMAT as u8,
                buffer_strategy: #strategy_name,
                plugin_name,
                vtable: vtable as *const ::std::ffi::c_void,
                free_buffer,
                method_count,
            }
        }
    }
}
fn generate_method_indices(ir: &InterfaceIR) -> TokenStream {
let indices: Vec<TokenStream> = ir
.methods
.iter()
.enumerate()
.map(|(i, m)| {
let const_name = format_ident!("METHOD_{}", m.name.to_string().to_uppercase());
let doc = format!("Vtable index for `{}`.", m.name);
quote! {
#[doc = #doc]
pub const #const_name: usize = #i;
}
})
.collect();
quote! { #(#indices)* }
}
#[allow(dead_code)]
fn _generate_client_deferred(ir: &InterfaceIR) -> TokenStream {
let trait_name = &ir.trait_name;
let client_name = format_ident!("{}Client", trait_name);
let methods: Vec<TokenStream> = ir
.methods
.iter()
.enumerate()
.map(|(i, m)| {
let method_name = &m.name;
let index = i;
let arg_types = &m.arg_types;
let arg_names = &m.arg_names;
let ret_type = match &m.return_type {
Some(ty) => quote! { #ty },
None => quote! { () },
};
let cap_check = if m.optional_since.is_some() {
let cap_bit = ir.methods.iter()
.filter(|mm| mm.optional_since.is_some())
.position(|mm| mm.name == m.name)
.unwrap_or(0) as u32;
quote! {
if !self.handle.has_capability(#cap_bit) {
return Err(fidius_host::CallError::NotImplemented { bit: #cap_bit });
}
}
} else {
quote! {}
};
if arg_types.len() == 1 {
let arg_type = &arg_types[0];
let arg_name = &arg_names[0];
quote! {
pub fn #method_name(&self, #arg_name: &#arg_type) -> Result<#ret_type, fidius_host::CallError> {
#cap_check
self.handle.call_method(#index, #arg_name)
}
}
} else if arg_types.is_empty() {
quote! {
pub fn #method_name(&self) -> Result<#ret_type, fidius_host::CallError> {
#cap_check
self.handle.call_method(#index, &())
}
}
} else {
quote! {
pub fn #method_name(&self, #(#arg_names: &#arg_types),*) -> Result<#ret_type, fidius_host::CallError> {
#cap_check
self.handle.call_method(#index, &(#(#arg_names.clone()),*))
}
}
}
})
.collect();
quote! {
pub struct #client_name {
handle: fidius_host::PluginHandle,
}
impl #client_name {
pub fn from_handle(handle: fidius_host::PluginHandle) -> Self {
Self { handle }
}
pub fn handle(&self) -> &fidius_host::PluginHandle {
&self.handle
}
#(#methods)*
}
}
}