proxy_enum/
lib.rs

1//! Emulate dynamic dispatch and ["sealed classes"](https://kotlinlang.org/docs/reference/sealed-classes.html) using a proxy enum, which defers all method calls to its variants.
2//!
3//! # Introduction
//! In Rust, dynamic dispatch is done using trait objects (`dyn Trait`).
5//! They enable us to have runtime polymorphism, a way of expressing that a type implements a
6//! certain trait while ignoring its concrete implementation.
7//!
//! ```ignore
9//! let animal: &dyn Animal = random_animal();
10//! animal.feed(); // may print "mew", "growl" or "squeak"
11//! ```
12//!
13//! Trait objects come with a downside though:
//! getting a concrete implementation back from a trait object (downcasting) is painful
//! (see [`std::any::Any`]).
16//!
//! If you know there are only a finite number of implementations to work with, an `enum` might be
//! better at expressing such a relationship:
//! ```ignore
20//! enum Animal {
21//!     Cat(Cat),
22//!     Lion(Lion),
23//!     Mouse(Mouse)
24//! }
25//!
26//! match random_animal() {
27//!     Animal::Cat(cat) => cat.feed(),
28//!     Animal::Lion(lion) => lion.feed(),
29//!     Animal::Mouse(mouse) => mouse.feed()
30//! }
31//! ```
32//! Some languages have special support for such types, like Kotlin with so called "sealed classes".
33//!
34//! Rust, however, does *not*.
35//!
36//! `proxy-enum` simplifies working with such types using procedural macros.
37//!
38//! # Usage
//! ```ignore
40//! #[proxy_enum::proxy(Animal)]
41//! mod proxy {
42//!     enum Animal {
43//!         Cat(Cat),
44//!         Lion(Lion),
45//!         Mouse(Mouse)
46//!     }
47//!
48//!     impl Animal {
49//!         #[implement]
50//!         fn feed(&self) {}
51//!     }
52//! }
53//! ```
//! This will expand to (roughly):
//! ```ignore
56//! mod proxy {
57//!     enum Animal {
58//!         Cat(Cat),
59//!         Lion(Lion),
60//!         Mouse(Mouse)
61//!     }
62//!
63//!     impl Animal {
64//!         fn feed(&self) {
65//!             match self {
66//!                 Animal::Cat(cat) => cat.feed(),
67//!                 Animal::Lion(lion) => lion.feed(),
68//!                 Animal::Mouse(mouse) => mouse.feed()
69//!             }
70//!         }
71//!     }
72//!     
73//!     impl From<Cat> for Animal {
74//!         fn from(from: Cat) -> Self {
75//!             Animal::Cat(from)
76//!         }
77//!     }
78//!
79//!     impl From<Lion> for Animal {
80//!         fn from(from: Lion) -> Self {
81//!             Animal::Lion(from)
82//!         }
83//!     }
84//!
85//!     impl From<Mouse> for Animal {
86//!         fn from(from: Mouse) -> Self {
87//!             Animal::Mouse(from)
88//!         }
89//!     }
90//! }
91//! ```
//! This, however, will only compile if `Cat`, `Lion` and `Mouse` all have a method called `feed`.
//! Since Rust has traits to express common functionality, trait implementations can be generated
//! too:
//! ```ignore
95//! #[proxy_enum::proxy(Animal)]
96//! mod proxy {
97//!     enum Animal {
98//!         Cat(Cat),
99//!         Lion(Lion),
100//!         Mouse(Mouse)
101//!     }
102//!
103//!     trait Eat {
104//!         fn feed(&self);
105//!     }
106//!
107//!     #[implement]
108//!     impl Eat for Animal {}
109//! }
110//! ```
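//! Since the macro has to know which methods the trait contains, the trait has to be declared
//! within the module. Implementations for external traits can be generated too, by redeclaring
//! the trait inside the module and annotating it with `#[external(path::to::Trait)]`; such
//! declarations are removed from the expanded module:
//!
//! ```ignore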
115//! #[proxy_enum::proxy(Animal)]
116//! mod proxy {
117//!     enum Animal {
118//!         Cat(Cat),
119//!         Lion(Lion),
120//!         Mouse(Mouse)
121//!     }
122//!
123//!     #[external(std::string::ToString)]
124//!     trait ToString {
125//!         fn to_string(&self) -> String;
126//!     }
127//!
128//!     #[implement]
129//!     impl std::string::ToString for Animal {}
130//! }
131//! ```
132
133extern crate proc_macro2;
134
135use proc_macro::TokenStream;
136use std::collections::HashMap;
137
138use proc_macro2::TokenStream as TokenStream2;
139use syn::visit_mut::VisitMut;
140use syn::{
141    parse2, parse_macro_input, Attribute, Fields, FnArg, Ident, ImplItem, Item, ItemEnum, ItemImpl,
142    ItemMod, ItemTrait, Pat, PatType, Path, PathArguments, Signature, TraitItem, Type, Variant,
143};
144
145use quote::quote;
146
/// Marker attribute asking the macro to generate code: placed on an empty method inside an
/// inherent `impl` of the proxy enum, or on an empty `impl Trait for ProxyEnum {}` block.
const IMPL_ATTR: &str = "implement";

/// Attribute on a trait declaration inside the module that names the real trait it mirrors,
/// e.g. `#[external(std::string::ToString)]`. Such declarations are removed from the output.
const EXT_ATTR: &str = "external";
149
150fn attr_idx(attrs: &[Attribute], ident: &str) -> Option<usize> {
151    (0..attrs.len()).find(|idx| attrs[*idx].path.is_ident(ident))
152}
153
154fn pop_attr(attrs: &mut Vec<Attribute>, ident: &str) -> Option<Attribute> {
155    attr_idx(attrs, ident).map(|idx| attrs.remove(idx))
156}
157
158fn find_attr<'a>(attrs: &'a [Attribute], ident: &str) -> Option<&'a Attribute> {
159    attr_idx(&attrs, ident).map(|idx| &attrs[idx])
160}
161
162fn gen_static_method_call(receiver: TokenStream2, signature: &Signature) -> TokenStream2 {
163    let method_ident = &signature.ident;
164
165    let args = signature
166        .inputs
167        .iter()
168        .skip(1) // `self`
169        .map(|a| match a {
170            FnArg::Typed(PatType { pat, .. }) => match &**pat {
171                Pat::Ident(ident) => &ident.ident,
172                other => panic!("unsupported pattern in parameter: `{}`", quote! { #other }),
173            },
174            _ => panic!("parameter binding must be an identifier"),
175        });
176
177    quote! { #receiver::#method_ident(__self #(, #args)*) }
178}
179
180struct WrapperVariant {
181    variant: Variant,
182    wrapped: Type,
183}
184
185impl From<Variant> for WrapperVariant {
186    fn from(variant: Variant) -> Self {
187        match &variant.fields {
188            Fields::Unnamed(a) if a.unnamed.len() == 1 => WrapperVariant {
189                variant: variant.clone(),
190                wrapped: a.unnamed.first().unwrap().ty.clone(),
191            },
192            _ => panic!("expected a variant with a single unnamed value"),
193        }
194    }
195}
196
197fn gen_match_block(
198    variants: &[WrapperVariant],
199    action: impl Fn(&WrapperVariant) -> TokenStream2,
200) -> TokenStream2 {
201    let branches = variants
202        .iter()
203        .map(|variant| {
204            let action = action(&variant);
205            let ident = &variant.variant.ident;
206            quote! { Self::#ident(__self) => #action }
207        })
208        .collect::<Vec<_>>();
209
210    quote! {
211        match self {
212            #(#branches),*
213        }
214    }
215}
216
217fn has_self_param(sig: &Signature) -> bool {
218    sig.inputs
219        .first()
220        .map(|param| match param {
221            FnArg::Receiver(..) => true,
222            FnArg::Typed(PatType { pat, .. }) => match &**pat {
223                Pat::Ident(ident) => &ident.ident.to_string() == "self",
224                _ => false,
225            },
226        })
227        .unwrap_or(false)
228}
229
230/// populate an empty `#[implement] impl Trait for ProxyEnum {}` block
231fn implement_trait(
232    trait_decl: &ItemTrait,
233    variants: &[WrapperVariant],
234    pseudo_impl: &mut ItemImpl,
235) {
236    assert!(pseudo_impl.items.is_empty());
237
238    let trait_ident = &trait_decl.ident;
239
240    let proxy_methods = trait_decl.items.iter().map(|i| match i {
241        TraitItem::Method(i) => {
242            let sig = &i.sig;
243            if !has_self_param(sig) {
244                match &i.default {
245                    Some(..) => return parse2(quote! { #i }).unwrap(),
246                    None => panic!(
247                        "`{}` has no self parameter or default implementation",
248                        quote! { #sig }
249                    ),
250                }
251            }
252
253            let match_block = gen_match_block(variants, |_| gen_static_method_call(quote! { #trait_ident }, sig));
254            let tokens = quote! { #sig { #match_block } };
255            parse2::<ImplItem>(tokens).unwrap()
256        }
257        _ => panic!(
258            "impl block annotated with `#[{}]` may only contain methods",
259            IMPL_ATTR
260        ),
261    });
262
263    pseudo_impl.items = proxy_methods.collect();
264}
265
266/// populate methods in a `impl ProxyEnum { #[implement] fn method(&self) {} }` block
267fn implement_raw(variants: &[WrapperVariant], pseudo_impl: &mut ItemImpl) {
268    pseudo_impl
269        .items
270        .iter_mut()
271        .flat_map(|i| match i {
272            ImplItem::Method(method) => pop_attr(&mut method.attrs, IMPL_ATTR).map(|_| method),
273            _ => None,
274        })
275        .for_each(|mut method| {
276            if !method.block.stmts.is_empty() {
277                panic!("method annotated with `#[{}]` must be empty", IMPL_ATTR)
278            }
279
280            let match_block = gen_match_block(variants, |variant| {
281                let ty = &variant.wrapped;
282                gen_static_method_call(quote! { #ty }, &method.sig)
283            });
284            let body = quote! { { #match_block } };
285            method.block = syn::parse2(body).unwrap();
286        });
287}
288
289struct GenerateProxyImpl {
290    proxy_enum: Ident,
291    variants: Option<Vec<WrapperVariant>>,
292    trait_defs: HashMap<String, ItemTrait>,
293}
294
295impl GenerateProxyImpl {
296    fn new(proxy_enum: Ident) -> Self {
297        GenerateProxyImpl {
298            proxy_enum,
299            variants: None,
300            trait_defs: HashMap::new(),
301        }
302    }
303
304    fn get_variants(&self) -> &[WrapperVariant] {
305        self.variants
306            .as_ref()
307            .unwrap_or_else(|| panic!("proxy enum must be defined first"))
308            .as_slice()
309    }
310
311    fn store_trait_decl(&mut self, attr: Option<Path>, decl: ItemTrait) {
312        let mut path = match attr {
313            Some(path) => quote! { #path },
314            None => {
315                let ident = &decl.ident;
316                quote! { #ident }
317            }
318        }
319        .to_string();
320        path.retain(|c| !c.is_whitespace());
321        self.trait_defs.insert(path, decl);
322    }
323
324    fn get_trait_decl(&self, mut path: Path) -> &ItemTrait {
325        path.segments
326            .iter_mut()
327            .for_each(|seg| seg.arguments = PathArguments::None);
328        let mut path = quote! { #path }.to_string();
329        path.retain(|c| !c.is_whitespace());
330
331        self.trait_defs
332            .get(&path)
333            .unwrap_or_else(|| panic!("missing declaration of trait `{}`", path))
334    }
335    
336    fn impl_from_variants(&self, module: &mut ItemMod) {
337        let proxy_enum = &self.proxy_enum;
338        for WrapperVariant { variant, wrapped, .. } in self.get_variants() {
339            let variant = &variant.ident;
340            let tokens = quote! {
341                impl From<#wrapped> for #proxy_enum {
342                    fn from(from: #wrapped) -> Self {
343                        #proxy_enum :: #variant(from)
344                    }
345                }
346            };
347            let from_impl: ItemImpl = syn::parse2(tokens).unwrap();
348            module.content.as_mut().unwrap().1.push(from_impl.into()); 
349        }
350    }
351}
352
353impl VisitMut for GenerateProxyImpl {
354    // store variants of our enum
355    fn visit_item_enum_mut(&mut self, i: &mut ItemEnum) {
356        if i.ident != self.proxy_enum {
357            return;
358        }
359        assert!(self.variants.is_none());
360
361        self.variants = Some(
362            i.variants
363                .iter()
364                .cloned()
365                .map(WrapperVariant::from)
366                .collect(),
367        );
368    }
369
370    fn visit_item_impl_mut(&mut self, impl_block: &mut ItemImpl) {
371        match impl_block.trait_.as_mut() {
372            // `impl Type { #[implement] fn abc() {} }
373            None => implement_raw(self.get_variants(), impl_block),
374            // #[implement] `impl Trait for Type {}`
375            Some((_, path, _)) => {
376                if pop_attr(&mut impl_block.attrs, IMPL_ATTR).is_some() {
377                    implement_trait(
378                        self.get_trait_decl(path.clone()),
379                        self.get_variants(),
380                        impl_block,
381                    );
382                }
383            }
384        };
385    }
386
387    fn visit_item_mod_mut(&mut self, module: &mut ItemMod) {
388        syn::visit_mut::visit_item_mod_mut(self, module);
389        // remove all items annotated with external
390        module.content.as_mut().unwrap().1.retain(|item| {
391            if let Item::Trait(ItemTrait { attrs, .. }) = item {
392                find_attr(&attrs, EXT_ATTR).is_none()
393            } else {
394                true
395            }
396        });
397        self.impl_from_variants(module);
398    }
399
400    // scan for trait declarations and store them
401    fn visit_item_trait_mut(&mut self, trait_def: &mut ItemTrait) {
402        let ext_attr = find_attr(&trait_def.attrs, EXT_ATTR).map(|attr| attr.parse_args().unwrap());
403        self.store_trait_decl(ext_attr, trait_def.clone());
404    }
405}
406
407#[proc_macro_attribute]
408pub fn proxy(attr: TokenStream, item: TokenStream) -> TokenStream {
409    let mut module = parse_macro_input!(item as ItemMod);
410    let proxy_enum = parse_macro_input!(attr as Ident);
411
412    GenerateProxyImpl::new(proxy_enum).visit_item_mod_mut(&mut module);
413
414    TokenStream::from(quote! { #module })
415}