extern crate proc_macro;
use proc_macro::TokenStream;
use syn::{parse_macro_input, parse_str, Ident, ImplItem, ImplItemMethod, ItemImpl, Type};
use quote::quote;
use std::collections::{HashSet, HashMap};
use std::sync::Mutex;
use proc_macro2::Span;
#[macro_use]
extern crate lazy_static;
lazy_static!{
static ref DEFAULT_TRAIT_IMPLS: Mutex<HashMap<String, DefaultTraitImpl>> = Mutex::new(HashMap::new());
}
struct DefaultTraitImpl {
pub trait_name: String,
pub methods: Vec<String>,
}
#[proc_macro_attribute]
pub fn default_trait_impl(_: TokenStream, input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as ItemImpl);
let pseudotrait = match *input.self_ty {
Type::Path(type_path) => {
match type_path.path.get_ident() {
Some(ident) => ident.to_string(),
None => return syntax_invalid_error(),
}
},
_ => return syntax_invalid_error(),
};
let trait_name = match input.trait_ {
Some(trait_tuple) => {
match trait_tuple.1.get_ident() {
Some(ident) => ident.to_string(),
None => return syntax_invalid_error(),
}
},
_ => return syntax_invalid_error(),
};
let methods: Vec<String> = input.items.iter().map(|method| {
return quote! {
#method
}.to_string()
}).collect();
DEFAULT_TRAIT_IMPLS.lock().unwrap().insert(pseudotrait, DefaultTraitImpl { trait_name, methods });
TokenStream::new()
}
fn syntax_invalid_error() -> TokenStream {
return quote! {
compile_error!("`default_trait_impl` expects to be given a syntactially valid trait implementation");
}.into()
}
#[proc_macro_attribute]
pub fn trait_impl(_: TokenStream, input: TokenStream) -> TokenStream {
let mut input = parse_macro_input!(input as ItemImpl);
let trait_name = match &input.trait_ {
Some(trait_tuple) => {
match trait_tuple.1.get_ident() {
Some(ident) => ident.to_string(),
None => return syntax_invalid_error(),
}
},
_ => return syntax_invalid_error(),
};
let mut methods = HashSet::new();
for item in &input.items {
if let ImplItem::Method(method) = item {
methods.insert(method.sig.ident.to_string());
}
}
match DEFAULT_TRAIT_IMPLS.lock().unwrap().get(&trait_name) {
Some(default_impl) => {
if let Some(trait_tuple) = &mut input.trait_ {
trait_tuple.1.segments[0].ident = Ident::new(&default_impl.trait_name, Span::call_site());
}
for default_impl_method in &default_impl.methods {
let parsed_default_method: ImplItemMethod = parse_str(default_impl_method).unwrap();
if !methods.contains(&parsed_default_method.sig.ident.to_string()) {
input.items.push(ImplItem::Method(parsed_default_method));
}
}
},
_ => return quote! {
compile_error!("`trait_impl` expects there to be a `default_trait_impl` for the trait it implements");
}.into()
}
let res = quote! {
#input
};
res.into()
}