extern crate proc_macro;
use std::borrow::Cow;
use std::collections::vec_deque::VecDeque;
use std::ops::Add;
use darling::FromMeta;
use proc_macro2::TokenStream;
use quote::{ToTokens, format_ident, quote};
use syn::visit::Visit;
use syn::{
AngleBracketedGenericArguments, Error, Expr, ExprLit, GenericArgument,
Ident, ItemFn, Lit, PatType, PathArguments, Result, ReturnType, Type,
TypePath,
};
struct FuncSignatureParser<'ast> {
args: Option<VecDeque<(String, &'ast Type)>>,
}
impl<'ast> FuncSignatureParser<'ast> {
fn new() -> Self {
Self { args: None }
}
#[inline(always)]
fn type_ident(type_path: &TypePath) -> &Ident {
&type_path.path.segments.last().unwrap().ident
}
#[inline(always)]
fn type_args(
type_path: &TypePath,
) -> Result<impl Iterator<Item = &GenericArgument>> {
if let PathArguments::AngleBracketed(
AngleBracketedGenericArguments { args, .. },
) = &type_path.path.segments.last().unwrap().arguments
{
Ok(args.into_iter())
} else {
Err(Error::new_spanned(type_path, "this type must have arguments"))
}
}
fn type_args_as_integers(
type_path: &TypePath,
error_msg: &str,
) -> Result<Vec<i64>> {
Self::type_args(type_path)?
.map(|arg| match arg {
GenericArgument::Const(Expr::Lit(ExprLit {
lit: Lit::Int(integer),
..
})) => integer.base10_parse(),
_ => Err(Error::new_spanned(type_path, error_msg)),
})
.collect::<Result<Vec<_>>>()
}
fn type_path_to_mangled_named(
type_path: &TypePath,
) -> Result<Cow<'static, str>> {
match Self::type_ident(type_path).to_string().as_str() {
"i32" | "i64" => Ok(Cow::Borrowed("i")),
"f32" | "f64" => Ok(Cow::Borrowed("f")),
"bool" => Ok(Cow::Borrowed("b")),
"PatternId" | "RuleId" => Ok(Cow::Borrowed("i")),
"RegexpId" => Ok(Cow::Borrowed("r")),
"Rc" => Ok(Cow::Borrowed("i")),
"RuntimeObjectHandle" => Ok(Cow::Borrowed("i")),
"RuntimeString" => Ok(Cow::Borrowed("s")),
"RangedInteger" => {
let error_msg = "RangedInteger must have MIN and MAX arguments (i.e: RangedInteger<0,256>)";
let args = Self::type_args_as_integers(type_path, error_msg)?;
let min = args
.first()
.ok_or_else(|| Error::new_spanned(type_path, error_msg))?;
let max = args
.get(1)
.ok_or_else(|| Error::new_spanned(type_path, error_msg))?;
Ok(Cow::Owned(format!("i:R{min:?}:{max:?}")))
}
"FixedLenString" => {
let error_msg = "FixedLenString must have a constant length (i.e: FixedLenString<32>)";
let args = Self::type_args_as_integers(type_path, error_msg)?;
let n = args
.first()
.ok_or_else(|| Error::new_spanned(type_path, error_msg))?;
Ok(Cow::Owned(format!("s:N{n:?}")))
}
"Lowercase" => {
let mut args = Self::type_args(type_path)?;
if let Some(GenericArgument::Type(Type::Path(p))) = args.next()
{
Ok(Self::type_path_to_mangled_named(p)?.add(":L"))
} else {
Err(Error::new_spanned(
type_path,
"Lowercase must have a type argument (i.e: <Lowercase<RuntimeString>>))",
))
}
}
"Uppercase" => {
let mut args = Self::type_args(type_path)?;
if let Some(GenericArgument::Type(Type::Path(p))) = args.next()
{
Ok(Self::type_path_to_mangled_named(p)?.add(":U"))
} else {
Err(Error::new_spanned(
type_path,
"Uppercase must have a type argument (i.e: <Uppercase<RuntimeString>>))",
))
}
}
type_ident => Err(Error::new_spanned(
type_path,
format!(
"type `{type_ident}` is not supported as argument or return type"
),
)),
}
}
fn mangled_type(ty: &Type) -> Result<Cow<'static, str>> {
match ty {
Type::Path(type_path) => {
if Self::type_ident(type_path) == "Option" {
if let PathArguments::AngleBracketed(angle_bracketed) =
&type_path.path.segments.last().unwrap().arguments
{
if let GenericArgument::Type(ty) =
angle_bracketed.args.first().unwrap()
{
Ok(Self::mangled_type(ty)?.add("u"))
} else {
unreachable!()
}
} else {
unreachable!()
}
} else {
Self::type_path_to_mangled_named(type_path)
}
}
Type::Group(group) => Self::mangled_type(group.elem.as_ref()),
Type::Tuple(tuple) => {
let mut result = String::new();
for elem in tuple.elems.iter() {
result.push_str(Self::mangled_type(elem)?.as_ref());
}
Ok(Cow::Owned(result))
}
_ => Err(Error::new_spanned(ty, "unsupported type")),
}
}
fn mangled_return_type(ty: &ReturnType) -> Result<Cow<'static, str>> {
match ty {
ReturnType::Default => Ok(Cow::Borrowed("")),
ReturnType::Type(_, ty) => Self::mangled_type(ty),
}
}
fn parse(&mut self, func: &'ast ItemFn) -> Result<String> {
self.args = Some(VecDeque::new());
for fn_arg in func.sig.inputs.iter() {
self.visit_fn_arg(fn_arg);
}
let mut args = self.args.take().unwrap();
let mut first_argument_is_ok = false;
if let Some((_, Type::Reference(ref_type))) = args.pop_front()
&& let Type::Path(type_) = ref_type.elem.as_ref()
{
first_argument_is_ok = Self::type_ident(type_) == "Caller";
}
if !first_argument_is_ok {
return Err(Error::new_spanned(
&func.sig,
format!(
"the first argument for function `{}` must be `&mut Caller<'_, ScanContext>`",
func.sig.ident
),
));
}
let mut mangled_name = String::from("@");
let mut first = true;
for (arg_name, arg_type) in args {
if !first {
mangled_name.push(',');
}
if !arg_name.is_empty() {
mangled_name.push_str(&arg_name);
mangled_name.push(':');
}
mangled_name.push_str(Self::mangled_type(arg_type)?.as_ref());
first = false;
}
mangled_name.push('@');
mangled_name.push_str(&Self::mangled_return_type(&func.sig.output)?);
Ok(mangled_name)
}
}
impl<'ast> Visit<'ast> for FuncSignatureParser<'ast> {
fn visit_pat_type(&mut self, pat_type: &'ast PatType) {
let name = if let syn::Pat::Ident(ident) = &*pat_type.pat {
ident.ident.to_string()
} else {
"".to_string()
};
self.args.as_mut().unwrap().push_back((name, pat_type.ty.as_ref()));
}
}
/// Arguments accepted by the attribute macro implemented in
/// `impl_wasm_export_macro`, parsed by `darling` from the attribute's
/// argument list.
#[derive(Debug, FromMeta)]
pub struct WasmExportArgs {
/// Name the function is exported under. When omitted, the Rust function's
/// own name is used.
name: Option<String>,
/// Name of the type this function is a method of. When present, the mangled
/// name is prefixed with `<type>::` and the descriptor's `method_of` field
/// is `Some(<type>)`.
method_of: Option<String>,
/// Sync mode: one of `none`, `before`, `after` or `both`. When omitted,
/// `both` is used.
sync: Option<String>,
/// Value stored in the descriptor's `public` field; `false` when omitted.
#[darling(default)]
public: bool,
}
/// Converts a sync mode into the integer literal emitted for the
/// `sync_flags` field of the generated export descriptor.
///
/// The modes map to `none` → 0, `before` → 1, `after` → 2 and `both` → 3,
/// where `both` is the bitwise OR of `before` and `after`. When `sync` is
/// `None`, `default` is used as the mode.
///
/// # Errors
///
/// Returns an error spanning the macro call site if the mode is not one of
/// the four values above.
fn sync_flags_literal(
    sync: Option<&str>,
    default: &str,
) -> Result<TokenStream> {
    const MODES: [(&str, u32); 4] =
        [("none", 0), ("before", 1), ("after", 2), ("both", 3)];

    let mode = sync.unwrap_or(default);
    let Some(&(_, bits)) = MODES.iter().find(|(name, _)| *name == mode) else {
        return Err(Error::new(
            proc_macro2::Span::call_site(),
            format!(
                "invalid sync mode `{mode}`, expected one of: none, before, after, both"
            ),
        ));
    };
    Ok(quote! { #bits })
}
/// Expands the WASM export attribute macro applied to `func`.
///
/// The output is `func` itself, unchanged, followed by a `WasmExport`
/// descriptor for it. The descriptor is a `pub(crate) static` named
/// `export__<fn>`. When the `inventory` feature is disabled, that static is
/// annotated with `distributed_slice(WASM_EXPORTS)`. When the feature is
/// enabled, an identical `WasmExport` value is also passed to
/// `inventory::submit!`. The `cfg` attributes in the generated code are
/// evaluated in the crate where the macro is expanded.
///
/// # Errors
///
/// Returns an error if the attribute arguments can't be parsed, if `func`
/// has no arguments, if `FuncSignatureParser::parse` rejects its signature,
/// or if the `sync` mode is invalid.
pub(crate) fn impl_wasm_export_macro(
    attr_args: Vec<darling::ast::NestedMeta>,
    func: ItemFn,
) -> Result<TokenStream> {
    let attr_args = WasmExportArgs::from_list(attr_args.as_slice())?;
    let rust_fn_name = &func.sig.ident;

    if func.sig.inputs.is_empty() {
        return Err(Error::new_spanned(
            &func.sig,
            format!(
                "function `{rust_fn_name}` must have at least one argument of type `&mut Caller<'_, ScanContext>`"
            ),
        ));
    }

    // Join the function's `#[doc = "..."]` attributes (what `///` comments
    // desugar to), one per line. The result becomes the export description.
    let docs = func
        .attrs
        .iter()
        .filter_map(|attr| {
            if let Ok(name_value) = attr.meta.require_name_value()
                && let Ok(ident) = name_value.path.require_ident()
                && ident == "doc"
                && let syn::Expr::Lit(syn::ExprLit {
                    lit: syn::Lit::Str(doc_str),
                    ..
                }) = &name_value.value
            {
                Some(doc_str.value())
            } else {
                None
            }
        })
        .collect::<Vec<String>>()
        .join("\n");

    let description = if docs.is_empty() {
        quote! { None }
    } else {
        quote! { Some(std::borrow::Cow::Borrowed(#docs)) }
    };

    let fn_name = attr_args.name.unwrap_or_else(|| rust_fn_name.to_string());
    // The first argument (the caller) is not counted. The subtraction can't
    // underflow because empty argument lists were rejected above.
    let num_args = func.sig.inputs.len() - 1;
    let public = attr_args.public;
    let export_ident = format_ident!("export__{}", rust_fn_name);
    let exported_fn_ident = format_ident!("WasmExportedFn{}", num_args);
    let args_signature = FuncSignatureParser::new().parse(&func)?;
    let sync_flags = sync_flags_literal(attr_args.sync.as_deref(), "both")?;
    let method_of = attr_args
        .method_of
        .as_ref()
        .map_or_else(|| quote! { None }, |m| quote! { Some(#m) });
    let mangled_fn_name = match &attr_args.method_of {
        Some(ty_name) => format!("{ty_name}::{fn_name}{args_signature}"),
        None => format!("{fn_name}{args_signature}"),
    };

    // The same field list is used for the static and the `inventory`
    // submission, so it is built once and interpolated twice.
    let export_fields = quote! {
        name: #fn_name,
        mangled_name: #mangled_fn_name,
        public: #public,
        rust_module_path: module_path!(),
        method_of: #method_of,
        sync_flags: #sync_flags,
        func: &#exported_fn_ident { target_fn: &#rust_fn_name },
        description: #description,
    };

    let fn_descriptor = quote! {
        #[allow(non_upper_case_globals)]
        #[cfg_attr(not(feature = "inventory"), distributed_slice(WASM_EXPORTS))]
        pub(crate) static #export_ident: WasmExport = WasmExport {
            #export_fields
        };
        #[cfg(feature = "inventory")]
        inventory::submit! {
            WasmExport {
                #export_fields
            }
        }
    };

    let mut token_stream = func.to_token_stream();
    token_stream.extend(fn_descriptor);
    Ok(token_stream)
}
#[cfg(test)]
mod tests {
    use crate::wasm_export::FuncSignatureParser;
    use syn::parse_quote;

    /// Mangling of return types, including wrappers and nested generics.
    #[test]
    fn func_signature_parser() {
        let mut parser = FuncSignatureParser::new();

        let func = parse_quote! {
            fn foo(caller: &mut Caller<'_, ScanContext>) { }
        };
        assert_eq!(parser.parse(&func).unwrap(), "@@");

        let func = parse_quote! {
            fn foo(caller: &mut Caller<'_, ScanContext>) -> i32 { 0 }
        };
        assert_eq!(parser.parse(&func).unwrap(), "@@i");

        let func = parse_quote! {
            fn foo(caller: &mut Caller<'_, ScanContext>) -> (i32, i32) { (0,0) }
        };
        assert_eq!(parser.parse(&func).unwrap(), "@@ii");

        let func = parse_quote! {
            fn foo(caller: &mut Caller<'_, ScanContext>, a: i32, b: i32) -> i32 { a + b }
        };
        assert_eq!(parser.parse(&func).unwrap(), "@a:i,b:i@i");

        let func = parse_quote! {
            fn foo(caller: &mut Caller<'_, ScanContext>) -> Option<()> { None }
        };
        assert_eq!(parser.parse(&func).unwrap(), "@@u");

        let func = parse_quote! {
            fn foo(caller: &mut Caller<'_, ScanContext>) -> Option<i64> { None }
        };
        assert_eq!(parser.parse(&func).unwrap(), "@@iu");

        let func = parse_quote! {
            fn foo(caller: &mut Caller<'_, ScanContext>) -> Option<(i64, f64)> { None }
        };
        assert_eq!(parser.parse(&func).unwrap(), "@@ifu");

        let func = parse_quote! {
            fn foo(caller: &mut Caller<'_, ScanContext>) -> (i64, RuntimeString) { }
        };
        assert_eq!(parser.parse(&func).unwrap(), "@@is");

        let func = parse_quote! {
            fn foo(caller: &mut Caller<'_, ScanContext>) -> Lowercase<RuntimeString> { }
        };
        assert_eq!(parser.parse(&func).unwrap(), "@@s:L");

        let func = parse_quote! {
            fn foo(caller: &mut Caller<'_, ScanContext>) -> Uppercase<RuntimeString> { }
        };
        assert_eq!(parser.parse(&func).unwrap(), "@@s:U");

        let func = parse_quote! {
            fn foo(caller: &mut Caller<'_, ScanContext>) -> Lowercase<FixedLenString<32>> { }
        };
        assert_eq!(parser.parse(&func).unwrap(), "@@s:N32:L");

        let func = parse_quote! {
            fn foo(caller: &mut Caller<'_, ScanContext>) -> Uppercase<FixedLenString<32>> { }
        };
        assert_eq!(parser.parse(&func).unwrap(), "@@s:N32:U");

        let func = parse_quote! {
            fn foo(caller: &mut Caller<'_, ScanContext>) -> FixedLenString<64> { }
        };
        assert_eq!(parser.parse(&func).unwrap(), "@@s:N64");

        let func = parse_quote! {
            fn foo(caller: &mut Caller<'_, ScanContext>) -> Option<Lowercase<FixedLenString<32>>> { }
        };
        assert_eq!(parser.parse(&func).unwrap(), "@@s:N32:Lu");
    }

    /// Mangling of named and unnamed arguments.
    #[test]
    fn func_signature_parser_arguments() {
        let mut parser = FuncSignatureParser::new();

        let func = parse_quote! {
            fn foo(caller: &mut Caller<'_, ScanContext>, x: RangedInteger<0, 256>) -> bool { true }
        };
        assert_eq!(parser.parse(&func).unwrap(), "@x:i:R0:256@b");

        let func = parse_quote! {
            fn foo(caller: &mut Caller<'_, ScanContext>, s: RuntimeString, p: PatternId) -> Option<RuntimeString> { None }
        };
        assert_eq!(parser.parse(&func).unwrap(), "@s:s,p:i@su");

        let func = parse_quote! {
            fn foo(caller: &mut Caller<'_, ScanContext>, x: f64, flag: bool, r: RegexpId) { }
        };
        assert_eq!(parser.parse(&func).unwrap(), "@x:f,flag:b,r:r@");

        // Arguments whose pattern is not a plain identifier have no name.
        let func = parse_quote! {
            fn foo(caller: &mut Caller<'_, ScanContext>, (a, b): (i32, i64)) { }
        };
        assert_eq!(parser.parse(&func).unwrap(), "@ii@");
    }

    /// Signatures that must be rejected.
    #[test]
    fn func_signature_parser_errors() {
        let mut parser = FuncSignatureParser::new();

        // The first argument is not a `Caller` reference.
        let func = parse_quote! {
            fn foo(x: i32) -> i32 { x }
        };
        assert!(parser.parse(&func).is_err());

        // Unsupported argument type.
        let func = parse_quote! {
            fn foo(caller: &mut Caller<'_, ScanContext>, v: Vec<u8>) { }
        };
        assert!(parser.parse(&func).is_err());

        // References other than the first argument are not supported.
        let func = parse_quote! {
            fn foo(caller: &mut Caller<'_, ScanContext>, s: &str) { }
        };
        assert!(parser.parse(&func).is_err());

        // `RangedInteger` requires both MIN and MAX.
        let func = parse_quote! {
            fn foo(caller: &mut Caller<'_, ScanContext>) -> RangedInteger<0> { }
        };
        assert!(parser.parse(&func).is_err());

        // `FixedLenString` requires an integer literal length.
        let func = parse_quote! {
            fn foo(caller: &mut Caller<'_, ScanContext>) -> FixedLenString<N> { }
        };
        assert!(parser.parse(&func).is_err());

        // Case wrappers require a type-path argument.
        let func = parse_quote! {
            fn foo(caller: &mut Caller<'_, ScanContext>, s: Lowercase<(i64, i64)>) { }
        };
        assert!(parser.parse(&func).is_err());
    }
}