// multi_default_trait_impl/lib.rs
extern crate proc_macro;

#[macro_use]
extern crate lazy_static;

use std::collections::{HashMap, HashSet};
use std::sync::{Mutex, PoisonError};

use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::quote;
use syn::{parse_macro_input, parse_str, Ident, ImplItem, ImplItemMethod, ItemImpl, Type};
65
66lazy_static!{
67 static ref DEFAULT_TRAIT_IMPLS: Mutex<HashMap<String, DefaultTraitImpl>> = Mutex::new(HashMap::new());
68}
69
/// One default implementation captured by `#[default_trait_impl]`.
struct DefaultTraitImpl {
    /// Identifier of the real trait that the pseudo-trait stands in for; the
    /// trait path of a `#[trait_impl]` block is rewritten to this name.
    pub trait_name: String,
    /// Every item of the default impl block (not only methods, despite the
    /// field name), each stored as the string form of its token stream so it
    /// can be re-parsed when a `#[trait_impl]` block is expanded.
    pub methods: Vec<String>,
}
74
75#[proc_macro_attribute]
76pub fn default_trait_impl(_: TokenStream, input: TokenStream) -> TokenStream {
77 let input = parse_macro_input!(input as ItemImpl);
78
79 let pseudotrait = match *input.self_ty {
80 Type::Path(type_path) => {
81 match type_path.path.get_ident() {
82 Some(ident) => ident.to_string(),
83 None => return syntax_invalid_error(),
84 }
85 },
86 _ => return syntax_invalid_error(),
87 };
88
89 let trait_name = match input.trait_ {
90 Some(trait_tuple) => {
91 match trait_tuple.1.get_ident() {
92 Some(ident) => ident.to_string(),
93 None => return syntax_invalid_error(),
94 }
95 },
96 _ => return syntax_invalid_error(),
97 };
98
99 let methods: Vec<String> = input.items.iter().map(|method| {
100 return quote! {
101 #method
102 }.to_string()
103 }).collect();
104
105 DEFAULT_TRAIT_IMPLS.lock().unwrap().insert(pseudotrait, DefaultTraitImpl { trait_name, methods });
106
107 TokenStream::new()
108}
109
110fn syntax_invalid_error() -> TokenStream {
111 return quote! {
112 compile_error!("`default_trait_impl` expects to be given a syntactially valid trait implementation");
113 }.into()
114}
115
116#[proc_macro_attribute]
117pub fn trait_impl(_: TokenStream, input: TokenStream) -> TokenStream {
118 let mut input = parse_macro_input!(input as ItemImpl);
119
120 let trait_name = match &input.trait_ {
121 Some(trait_tuple) => {
122 match trait_tuple.1.get_ident() {
123 Some(ident) => ident.to_string(),
124 None => return syntax_invalid_error(),
125 }
126 },
127 _ => return syntax_invalid_error(),
128 };
129
130 let mut methods = HashSet::new();
131 for item in &input.items {
132 if let ImplItem::Method(method) = item {
133 methods.insert(method.sig.ident.to_string());
134 }
135 }
136
137 match DEFAULT_TRAIT_IMPLS.lock().unwrap().get(&trait_name) {
138 Some(default_impl) => {
139 if let Some(trait_tuple) = &mut input.trait_ {
140 trait_tuple.1.segments[0].ident = Ident::new(&default_impl.trait_name, Span::call_site());
141 }
142
143 for default_impl_method in &default_impl.methods {
144 let parsed_default_method: ImplItemMethod = parse_str(default_impl_method).unwrap();
145 if !methods.contains(&parsed_default_method.sig.ident.to_string()) {
146 input.items.push(ImplItem::Method(parsed_default_method));
147 }
148 }
149 },
150 _ => return quote! {
151 compile_error!("`trait_impl` expects there to be a `default_trait_impl` for the trait it implements");
152 }.into()
153 }
154
155 let res = quote! {
156 #input
157 };
158 res.into()
159}