use inflector::Inflector;
use proc_macro::TokenStream;
use proc_macro2::{Ident, Span, TokenStream as TokenStream2, TokenTree};
use quote::{quote, ToTokens};
use syn::{
parse::{Parse, ParseStream, Parser, Result},
FnArg, ItemEnum, Pat, PatIdent, PatType, Token,
};
/// Custom keywords recognised while parsing the attribute arguments.
mod kw {
// `async_trait`: looked for by `Args::parse` to switch on async code generation.
syn::custom_keyword!(async_trait);
// `Send`: expected by `IsLocal::parse` right after `?` inside the parentheses.
syn::custom_keyword!(Send);
}
/// The enum the expansion operates on, together with the name of the handler
/// trait that is generated for it.
#[derive(Debug)]
pub struct Item {
/// The enum definition as parsed; [`expand`] re-emits it unchanged.
orig_enum: ItemEnum,
/// Identifier of the generated handler trait: the enum's identifier with a
/// `Handler` suffix appended.
handler_trait_ident: Ident,
}
impl Parse for Item {
fn parse(input: ParseStream) -> Result<Self> {
let orig_enum = ItemEnum::parse(input)?;
let handler_trait_ident = Ident::new(
&format!("{}Handler", orig_enum.ident),
Span::call_site(),
);
Ok(Self {
orig_enum,
handler_trait_ident,
})
}
}
impl Item {
fn handler_trait_doc(&self) -> String {
format!(
"Handler of [`{0}`] variants.\n\nUsing [`{0}::dispatch_with`] \
method dispatches [`{0}`] variants to appropriate methods of \
this trait.",
self.orig_enum.ident
)
}
fn dispatch_with_method_doc(&self) -> String {
format!(
"Dispatches [`{0}`] with given [`{0}Handler`].",
self.orig_enum.ident
)
}
fn handler_trait(&self, args: &Args) -> TokenStream2 {
let self_kind = args.self_kind.clone();
let maybe_async = args.maybe_async_token();
let handler_trait_methods: Vec<_> = self
.orig_enum
.variants
.iter()
.map(|v| {
let fn_name_ident = syn::Ident::new(
&to_handler_fn_name(&v.ident.to_string()),
Span::call_site(),
);
let handler_fn_args = match v.fields {
syn::Fields::Named(ref fields) => {
let handler_fn_args: Vec<_> = fields
.named
.iter()
.map(|f| {
let ident = f.ident.as_ref().unwrap();
let ty = &f.ty;
quote! { #ident: #ty }
})
.collect();
quote! { #(#handler_fn_args),* }
}
syn::Fields::Unnamed(ref fields) => {
let handler_fn_args: Vec<_> = fields
.unnamed
.iter()
.map(|f| f.ty.clone())
.collect();
quote! { data: (#(#handler_fn_args),*) }
}
syn::Fields::Unit => quote! {},
};
let doc = format!(
"Handles [`{0}::{1}`] variant of [`{0}`].",
self.orig_enum.ident,
v.ident.to_string(),
);
quote! {
#[doc = #doc]
#maybe_async fn #fn_name_ident(
#self_kind,
#handler_fn_args
) -> Self::Output;
}
})
.collect();
let trait_doc = self.handler_trait_doc();
let handler_trait_ident = self.handler_trait_ident.clone();
let maybe_async_trait_macro = args.maybe_async_trait_macro();
let vis = self.orig_enum.vis.clone();
quote! {
#[automatically_derived]
#[allow(clippy::needless_arbitrary_self_type)]
#[doc = #trait_doc]
#maybe_async_trait_macro
#vis trait #handler_trait_ident {
type Output;
#(#handler_trait_methods)*
}
}
}
}
/// Flag recording how `async_trait` was requested: `true` when it was followed
/// by `(?Send)`, `false` when it was given without parentheses.
#[derive(Debug, PartialEq)]
struct IsLocal(bool);
impl Parse for IsLocal {
    /// Parses what follows the `async_trait` keyword.
    ///
    /// Empty input yields `IsLocal(false)`; otherwise a parenthesised
    /// `?Send` is required and yields `IsLocal(true)`.
    ///
    /// # Errors
    ///
    /// Fails when the remaining input is not a parenthesised group starting
    /// with `?` followed by `Send`.
    fn parse(input: ParseStream) -> Result<Self> {
        if input.is_empty() {
            return Ok(Self(false));
        }

        let content;
        syn::parenthesized!(content in input);
        content.parse::<Token![?]>()?;
        content.parse::<kw::Send>()?;
        Ok(Self(true))
    }
}
/// Arguments controlling code generation, parsed from the attribute's
/// argument list (see the `Parse` impl) or taken from `Default`.
#[derive(Debug, PartialEq)]
pub struct Args {
/// Typed `self` argument used as the receiver of every generated handler
/// method; the default is `self: &mut Self`.
self_kind: PatType,
/// `None` when `async_trait` was not requested (synchronous code is
/// generated); `Some(_)` otherwise, with the flag telling whether `(?Send)`
/// was specified.
async_trait: Option<IsLocal>,
}
impl Args {
fn maybe_async_trait_macro(&self) -> TokenStream2 {
match &self.async_trait {
None => {
quote! {}
}
Some(is_local) => {
if is_local.0 {
quote! {
#[async_trait::async_trait(?Send)]
}
} else {
quote! {
#[async_trait::async_trait]
}
}
}
}
}
fn maybe_await_token(&self) -> TokenStream2 {
if self.async_trait.is_some() {
quote! {.await}
} else {
quote! {}
}
}
fn maybe_async_token(&self) -> TokenStream2 {
if self.async_trait.is_some() {
quote! {async}
} else {
quote! {}
}
}
fn dispatch_with_handler_arg(&self) -> FnArg {
let mut handler_arg = self.self_kind.clone();
handler_arg.pat = Box::new(Pat::Ident(PatIdent {
attrs: Vec::new(),
by_ref: None,
mutability: None,
ident: Ident::new("handler", Span::call_site()),
subpat: None,
}));
let handler_arg: TokenStream2 = handler_arg
.to_token_stream()
.into_iter()
.map(|token| match &token {
TokenTree::Ident(ident) => {
if *ident == "Self" {
TokenTree::Ident(proc_macro2::Ident::new(
"T",
ident.span(),
))
} else {
token
}
}
_ => token,
})
.collect();
FnArg::parse.parse2(quote! {#handler_arg}).unwrap()
}
}
impl Default for Args {
fn default() -> Self {
let self_kind = FnArg::parse.parse2(quote! {self: &mut Self}).unwrap();
let self_kind = match self_kind {
FnArg::Typed(self_kind) => self_kind,
FnArg::Receiver(_) => unreachable!(),
};
Self {
self_kind,
async_trait: None,
}
}
}
impl Parse for Args {
fn parse(input: ParseStream) -> Result<Self> {
let mut args = Args::default();
if input.is_empty() {
return Ok(args);
}
if input.peek(Token![self]) && input.peek2(Token![:]) {
let self_kind = FnArg::parse(input)?;
let self_kind = match self_kind {
FnArg::Typed(self_kind) => self_kind,
FnArg::Receiver(_) => unreachable!(),
};
args.self_kind = self_kind;
}
if input.peek(Token![,]) {
input.parse::<Token![,]>()?;
}
if input.peek(kw::async_trait) {
input.parse::<kw::async_trait>()?;
args.async_trait = Some(IsLocal::parse(input)?);
}
Ok(args)
}
}
/// Derives the handler method name for a variant: the variant name converted
/// to snake case and prefixed with `on_` (e.g. `SomeTestTrait` becomes
/// `on_some_test_trait`).
fn to_handler_fn_name(name: &str) -> String {
    format!("on_{}", name.to_snake_case())
}
pub fn expand(item: Item, args: &Args) -> TokenStream {
let enum_ident = item.orig_enum.ident.clone();
let dispatch_variants: Vec<_> = item
.orig_enum
.variants
.iter()
.map(|v| {
let variant_ident = v.ident.clone();
let handler_fn_ident = syn::Ident::new(
&to_handler_fn_name(&variant_ident.to_string()),
Span::call_site(),
);
let fields: &Vec<_> = &v
.fields
.iter()
.enumerate()
.map(|(i, f)| {
f.ident.clone().unwrap_or_else(|| {
syn::Ident::new(&format!("f{}", i), Span::call_site())
})
})
.collect();
match v.fields {
syn::Fields::Named(_) => quote! {
#enum_ident::#variant_ident {#(#fields),*} => {
handler.#handler_fn_ident(#(#fields),*)
},
},
syn::Fields::Unnamed(_) => quote! {
#enum_ident::#variant_ident(#(#fields),*) => {
handler.#handler_fn_ident((#(#fields),*))
},
},
syn::Fields::Unit => quote! {
#enum_ident::#variant_ident => handler.#handler_fn_ident(),
},
}
})
.collect();
let handler_kind = args.dispatch_with_handler_arg();
let method_doc = item.dispatch_with_method_doc();
let handler_trait = item.handler_trait(&args);
let maybe_async = args.maybe_async_token();
let maybe_await = args.maybe_await_token();
let orig_enum = item.orig_enum;
let handler_trait_ident = item.handler_trait_ident;
TokenStream::from(quote! {
#orig_enum
#handler_trait
#[automatically_derived]
impl #enum_ident {
#[doc = #method_doc]
pub #maybe_async fn dispatch_with<T: #handler_trait_ident>(
self, #handler_kind,
) -> <T as #handler_trait_ident>::Output {
match self {
#(#dispatch_variants)*
}#maybe_await
}
}
})
}
#[cfg(test)]
mod to_handler_fn_name_spec {
    use super::*;

    #[test]
    fn converts_name_from_camel_case_to_snake_case() {
        let cases = [
            ("SomeTestTrait", "on_some_test_trait"),
            ("RPCConnection", "on_rpc_connection"),
            ("RConnection", "on_r_connection"),
            ("RTCPeerConnection", "on_rtc_peer_connection"),
            ("testString", "on_test_string"),
            ("testtest", "on_testtest"),
            ("Some", "on_some"),
            ("S", "on_s"),
            ("s", "on_s"),
            ("ASDF", "on_asdf"),
        ];
        for &(name, expected) in cases.iter() {
            assert_eq!(to_handler_fn_name(name), expected);
        }
    }

    mod parse_args {
        use super::*;

        /// Parses `tokens` as a function argument, panicking on failure.
        fn fn_arg(tokens: TokenStream2) -> FnArg {
            FnArg::parse.parse2(tokens).unwrap()
        }

        /// Parses `input` as [`Args`] and checks the derived `handler`
        /// argument, the `async_trait` flag (`None` when not requested) and
        /// the receiver, in that order.
        fn assert_parsed(
            input: TokenStream2,
            expected_handler: TokenStream2,
            expected_is_local: Option<bool>,
            expected_receiver: TokenStream2,
        ) {
            let args = Args::parse.parse2(input).unwrap();
            assert_eq!(args.dispatch_with_handler_arg(), fn_arg(expected_handler));
            assert_eq!(args.async_trait.map(|flag| flag.0), expected_is_local);
            assert_eq!(FnArg::Typed(args.self_kind), fn_arg(expected_receiver));
        }

        #[test]
        fn empty() {
            assert_parsed(
                quote! {},
                quote! {handler: &mut T},
                None,
                quote! {self: &mut Self},
            );
        }

        #[test]
        fn self_ref() {
            assert_parsed(
                quote! {self: &Self},
                quote! {handler: &T},
                None,
                quote! {self: &Self},
            );
        }

        #[test]
        fn self_rc() {
            assert_parsed(
                quote! {self: std::rc::Rc<Self>},
                quote! {handler: std::rc::Rc<T>},
                None,
                quote! {self: std::rc::Rc<Self>},
            );
        }

        #[test]
        fn async_trait_not_local() {
            assert_parsed(
                quote! {async_trait},
                quote! {handler: &mut T},
                Some(false),
                quote! {self: &mut Self},
            );
        }

        #[test]
        fn async_trait_local() {
            assert_parsed(
                quote! {async_trait(?Send)},
                quote! {handler: &mut T},
                Some(true),
                quote! {self: &mut Self},
            );
        }

        #[test]
        fn self_arc_and_async_trait_not_send() {
            // Plain `async_trait` (no `(?Send)`), so the local flag is `false`.
            assert_parsed(
                quote! {self: Arc<Self>, async_trait},
                quote! {handler: Arc<T>},
                Some(false),
                quote! {self: Arc<Self>},
            );
        }
    }
}