1extern crate proc_macro2;
134
135use proc_macro::TokenStream;
136use std::collections::HashMap;
137
138use proc_macro2::TokenStream as TokenStream2;
139use syn::visit_mut::VisitMut;
140use syn::{
141 parse2, parse_macro_input, Attribute, Fields, FnArg, Ident, ImplItem, Item, ItemEnum, ItemImpl,
142 ItemMod, ItemTrait, Pat, PatType, Path, PathArguments, Signature, TraitItem, Type, Variant,
143};
144
145use quote::quote;
146
/// Name of the marker attribute that is popped from trait impl blocks and from
/// inherent methods to request a generated, variant-dispatching body.
const IMPL_ATTR: &str = "implement";
/// Name of the marker attribute looked up on trait declarations: its argument
/// is parsed as the path under which the declaration is stored, and trait
/// items carrying it are removed from the module after traversal.
const EXT_ATTR: &str = "external";
149
150fn attr_idx(attrs: &[Attribute], ident: &str) -> Option<usize> {
151 (0..attrs.len()).find(|idx| attrs[*idx].path.is_ident(ident))
152}
153
154fn pop_attr(attrs: &mut Vec<Attribute>, ident: &str) -> Option<Attribute> {
155 attr_idx(attrs, ident).map(|idx| attrs.remove(idx))
156}
157
158fn find_attr<'a>(attrs: &'a [Attribute], ident: &str) -> Option<&'a Attribute> {
159 attr_idx(&attrs, ident).map(|idx| &attrs[idx])
160}
161
162fn gen_static_method_call(receiver: TokenStream2, signature: &Signature) -> TokenStream2 {
163 let method_ident = &signature.ident;
164
165 let args = signature
166 .inputs
167 .iter()
168 .skip(1) .map(|a| match a {
170 FnArg::Typed(PatType { pat, .. }) => match &**pat {
171 Pat::Ident(ident) => &ident.ident,
172 other => panic!("unsupported pattern in parameter: `{}`", quote! { #other }),
173 },
174 _ => panic!("parameter binding must be an identifier"),
175 });
176
177 quote! { #receiver::#method_ident(__self #(, #args)*) }
178}
179
180struct WrapperVariant {
181 variant: Variant,
182 wrapped: Type,
183}
184
185impl From<Variant> for WrapperVariant {
186 fn from(variant: Variant) -> Self {
187 match &variant.fields {
188 Fields::Unnamed(a) if a.unnamed.len() == 1 => WrapperVariant {
189 variant: variant.clone(),
190 wrapped: a.unnamed.first().unwrap().ty.clone(),
191 },
192 _ => panic!("expected a variant with a single unnamed value"),
193 }
194 }
195}
196
197fn gen_match_block(
198 variants: &[WrapperVariant],
199 action: impl Fn(&WrapperVariant) -> TokenStream2,
200) -> TokenStream2 {
201 let branches = variants
202 .iter()
203 .map(|variant| {
204 let action = action(&variant);
205 let ident = &variant.variant.ident;
206 quote! { Self::#ident(__self) => #action }
207 })
208 .collect::<Vec<_>>();
209
210 quote! {
211 match self {
212 #(#branches),*
213 }
214 }
215}
216
217fn has_self_param(sig: &Signature) -> bool {
218 sig.inputs
219 .first()
220 .map(|param| match param {
221 FnArg::Receiver(..) => true,
222 FnArg::Typed(PatType { pat, .. }) => match &**pat {
223 Pat::Ident(ident) => &ident.ident.to_string() == "self",
224 _ => false,
225 },
226 })
227 .unwrap_or(false)
228}
229
230fn implement_trait(
232 trait_decl: &ItemTrait,
233 variants: &[WrapperVariant],
234 pseudo_impl: &mut ItemImpl,
235) {
236 assert!(pseudo_impl.items.is_empty());
237
238 let trait_ident = &trait_decl.ident;
239
240 let proxy_methods = trait_decl.items.iter().map(|i| match i {
241 TraitItem::Method(i) => {
242 let sig = &i.sig;
243 if !has_self_param(sig) {
244 match &i.default {
245 Some(..) => return parse2(quote! { #i }).unwrap(),
246 None => panic!(
247 "`{}` has no self parameter or default implementation",
248 quote! { #sig }
249 ),
250 }
251 }
252
253 let match_block = gen_match_block(variants, |_| gen_static_method_call(quote! { #trait_ident }, sig));
254 let tokens = quote! { #sig { #match_block } };
255 parse2::<ImplItem>(tokens).unwrap()
256 }
257 _ => panic!(
258 "impl block annotated with `#[{}]` may only contain methods",
259 IMPL_ATTR
260 ),
261 });
262
263 pseudo_impl.items = proxy_methods.collect();
264}
265
266fn implement_raw(variants: &[WrapperVariant], pseudo_impl: &mut ItemImpl) {
268 pseudo_impl
269 .items
270 .iter_mut()
271 .flat_map(|i| match i {
272 ImplItem::Method(method) => pop_attr(&mut method.attrs, IMPL_ATTR).map(|_| method),
273 _ => None,
274 })
275 .for_each(|mut method| {
276 if !method.block.stmts.is_empty() {
277 panic!("method annotated with `#[{}]` must be empty", IMPL_ATTR)
278 }
279
280 let match_block = gen_match_block(variants, |variant| {
281 let ty = &variant.wrapped;
282 gen_static_method_call(quote! { #ty }, &method.sig)
283 });
284 let body = quote! { { #match_block } };
285 method.block = syn::parse2(body).unwrap();
286 });
287}
288
289struct GenerateProxyImpl {
290 proxy_enum: Ident,
291 variants: Option<Vec<WrapperVariant>>,
292 trait_defs: HashMap<String, ItemTrait>,
293}
294
295impl GenerateProxyImpl {
296 fn new(proxy_enum: Ident) -> Self {
297 GenerateProxyImpl {
298 proxy_enum,
299 variants: None,
300 trait_defs: HashMap::new(),
301 }
302 }
303
304 fn get_variants(&self) -> &[WrapperVariant] {
305 self.variants
306 .as_ref()
307 .unwrap_or_else(|| panic!("proxy enum must be defined first"))
308 .as_slice()
309 }
310
311 fn store_trait_decl(&mut self, attr: Option<Path>, decl: ItemTrait) {
312 let mut path = match attr {
313 Some(path) => quote! { #path },
314 None => {
315 let ident = &decl.ident;
316 quote! { #ident }
317 }
318 }
319 .to_string();
320 path.retain(|c| !c.is_whitespace());
321 self.trait_defs.insert(path, decl);
322 }
323
324 fn get_trait_decl(&self, mut path: Path) -> &ItemTrait {
325 path.segments
326 .iter_mut()
327 .for_each(|seg| seg.arguments = PathArguments::None);
328 let mut path = quote! { #path }.to_string();
329 path.retain(|c| !c.is_whitespace());
330
331 self.trait_defs
332 .get(&path)
333 .unwrap_or_else(|| panic!("missing declaration of trait `{}`", path))
334 }
335
336 fn impl_from_variants(&self, module: &mut ItemMod) {
337 let proxy_enum = &self.proxy_enum;
338 for WrapperVariant { variant, wrapped, .. } in self.get_variants() {
339 let variant = &variant.ident;
340 let tokens = quote! {
341 impl From<#wrapped> for #proxy_enum {
342 fn from(from: #wrapped) -> Self {
343 #proxy_enum :: #variant(from)
344 }
345 }
346 };
347 let from_impl: ItemImpl = syn::parse2(tokens).unwrap();
348 module.content.as_mut().unwrap().1.push(from_impl.into());
349 }
350 }
351}
352
353impl VisitMut for GenerateProxyImpl {
354 fn visit_item_enum_mut(&mut self, i: &mut ItemEnum) {
356 if i.ident != self.proxy_enum {
357 return;
358 }
359 assert!(self.variants.is_none());
360
361 self.variants = Some(
362 i.variants
363 .iter()
364 .cloned()
365 .map(WrapperVariant::from)
366 .collect(),
367 );
368 }
369
370 fn visit_item_impl_mut(&mut self, impl_block: &mut ItemImpl) {
371 match impl_block.trait_.as_mut() {
372 None => implement_raw(self.get_variants(), impl_block),
374 Some((_, path, _)) => {
376 if pop_attr(&mut impl_block.attrs, IMPL_ATTR).is_some() {
377 implement_trait(
378 self.get_trait_decl(path.clone()),
379 self.get_variants(),
380 impl_block,
381 );
382 }
383 }
384 };
385 }
386
387 fn visit_item_mod_mut(&mut self, module: &mut ItemMod) {
388 syn::visit_mut::visit_item_mod_mut(self, module);
389 module.content.as_mut().unwrap().1.retain(|item| {
391 if let Item::Trait(ItemTrait { attrs, .. }) = item {
392 find_attr(&attrs, EXT_ATTR).is_none()
393 } else {
394 true
395 }
396 });
397 self.impl_from_variants(module);
398 }
399
400 fn visit_item_trait_mut(&mut self, trait_def: &mut ItemTrait) {
402 let ext_attr = find_attr(&trait_def.attrs, EXT_ATTR).map(|attr| attr.parse_args().unwrap());
403 self.store_trait_decl(ext_attr, trait_def.clone());
404 }
405}
406
407#[proc_macro_attribute]
408pub fn proxy(attr: TokenStream, item: TokenStream) -> TokenStream {
409 let mut module = parse_macro_input!(item as ItemMod);
410 let proxy_enum = parse_macro_input!(attr as Ident);
411
412 GenerateProxyImpl::new(proxy_enum).visit_item_mod_mut(&mut module);
413
414 TokenStream::from(quote! { #module })
415}