associated_derive/
lib.rs

1//! Derive macro for `Associated`.
2//!
3//! ## Usage
4//!
5//! Add `#[derive(Associated)]` to an enum definition. This is not compatible with structs or unions.
6//!
//! When deriving `Associated` you must include an `#[associated(Type = associated_type)]` attribute
//! beneath the `#[derive(Associated)]` attribute, replacing `associated_type` with the type of the
//! constants you want to associate with the enum variants.
10//!
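//! For each and **every** variant of the enum you must include either an `#[assoc(expr)]` or an
//! `#[assoc_const(const_expr)]` attribute, placed above the variant or inline before it, with `expr`
//! or `const_expr` replaced by the expression or value you want to associate. `expr` must evaluate to
//! a `&'static` reference to the associated type, whereas `const_expr` must be a constant expression
//! of the associated type itself.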
14//!
15//! ### Example
16//!
//! ```rust,ignore
//! #[derive(Associated)]
//! #[associated(Type = &'static str)]
//! enum Phonetic {
//!     #[assoc_const("Alpha")] Alpha,
//!     #[assoc(&"Bravo")] // #[assoc] requires an expression of type &'static Type
//!     Bravo = 3, // supports explicit enum discriminants
//!     // ...
//! }
//!
//! Phonetic::Alpha.get_associated(); // returns a static lifetime reference to "Alpha"
//! ```
29//!
30//! #### Generated Implementation
31//!
//! ```rust,ignore
33//! impl associated::Associated for Phonetic {
34//!     type AssociatedType = &'static str;
35//!     fn get_associated(&self) -> &'static Self::AssociatedType {
36//!         match self {
37//!             Phonetic::Alpha => {
38//!                 const ASSOCIATED: &'static str = "Alpha";
39//!                 &ASSOCIATED
40//!             },
41//!             Phonetic::Bravo => &"Bravo",
42//!         }
43//!     }
44//! }
45//! ```
46//!
47//! ### Note
48//!
//! If you give a variant both an `#[assoc]` and an `#[assoc_const]` attribute, or multiple `#[assoc]`
//! or `#[assoc_const]` attributes, only the first will be considered. Including more than one is not
//! currently an error, but this **will** change, so use only one `#[assoc]` or `#[assoc_const]`
//! attribute per variant.
53//!
54//! See [associated](https://docs.rs/associated) for retrieving associated constants.
55
56use proc_macro::{self, TokenStream};
57use proc_macro2::TokenStream as TokenStream2;
58use quote::quote;
59use syn::{
60    parse::{Error as ParseError, Parse, ParseStream, Result as ParseResult},
61    parse_macro_input,
62    punctuated::Punctuated,
63    spanned::Spanned,
64    token::Comma,
65    Attribute, Binding, DeriveInput, Expr, Fields, Ident, Type, Variant,
66};
67
68struct Args {
69    assoc_type: Type,
70}
71
72enum AssocKind {
73    Constant,
74    Static,
75}
76
77struct Assoc<'a> {
78    kind: AssocKind,
79    attr: &'a Attribute,
80}
81
82impl Parse for Args {
83    fn parse(input: ParseStream) -> ParseResult<Self> {
84        let b = Binding::parse(input)?;
85        if b.ident.to_string() == "Type" {
86            return Ok(Args { assoc_type: b.ty });
87        }
88        Err(ParseError::new(b.ident.span(), "Expected `Type`"))
89    }
90}
91
92fn generate_match_body(
93    enum_ident: &Ident,
94    associated_type: &Type,
95    associated_variants: &Vec<(&Ident, &Fields, Expr, AssocKind)>,
96) -> TokenStream2 {
97    let mut match_block = TokenStream2::new();
98    match_block.extend(
99        associated_variants
100            .iter()
101            .map(|(variant_ident, fields, expr, kind)| {
102                let pattern = match fields {
103                    syn::Fields::Named(_) => quote! {{..}},
104                    syn::Fields::Unnamed(_) => quote! {(..)},
105                    syn::Fields::Unit => quote! {},
106                };
107                match kind {
108                    AssocKind::Constant => {
109                        quote! {
110                            #enum_ident::#variant_ident #pattern => {
111                                const ASSOCIATED: #associated_type = #expr;
112                                &ASSOCIATED
113                            },
114                        }
115                    }
116                    AssocKind::Static => {
117                        quote! {
118                            #enum_ident::#variant_ident #pattern => #expr,
119                        }
120                    }
121                }
122            }),
123    );
124    match_block
125}
126
127/// Takes in a sequence of enum variants and parses their attributes to return a list of (variant, associated value) groupings.
128///
129/// Fields are included in the grouping to control which pattern glyph to generate for that variant.
130/// AssocKind holds whether the attribute was assoc or assoc_const
131fn parse_associated_values<'a>(
132    variants: &'a Punctuated<Variant, Comma>,
133    enum_ident: &Ident,
134) -> Result<Vec<(&'a Ident, &'a Fields, Expr, AssocKind)>, TokenStream> {
135    let mut associated_values = Vec::new();
136    for v in variants.iter() {
137        if let Some(assoc) = v.attrs.iter().find_map(|attr| match attr.path.get_ident() {
138            Some(i) => {
139                let i = i.to_string();
140                if i == "assoc" {
141                    Some(Assoc {
142                        kind: AssocKind::Static,
143                        attr,
144                    })
145                } else if i == "assoc_const" {
146                    Some(Assoc {
147                        kind: AssocKind::Constant,
148                        attr,
149                    })
150                } else {
151                    None
152                }
153            }
154            None => None,
155        }) {
156            let expr = match assoc.attr.parse_args::<Expr>() {
157                Ok(expr) => expr,
158                Err(e) => return Err(e.to_compile_error().into()),
159            };
160
161            associated_values.push((&v.ident, &v.fields, expr, assoc.kind));
162        } else {
163            return Err(ParseError::new(
164                v.span(),
165                format!(
166                    "Cannot derive `Associated` for `{}`: Missing `assoc` or `assoc_const` attribute on variant `{}`",
167                    enum_ident.to_string(),
168                    v.ident.to_string()
169                )
170            )
171            .to_compile_error()
172            .into());
173        }
174    }
175    Ok(associated_values)
176}
177
178/// See [crate-level] documentation.
179///
180/// [crate-level]: crate
181#[proc_macro_derive(Associated, attributes(associated, assoc, assoc_const))]
182pub fn associated_derive(input: TokenStream) -> TokenStream {
183    let DeriveInput {
184        attrs,
185        vis: _,
186        ident,
187        generics,
188        data,
189    } = parse_macro_input!(input);
190    let associated = match (&attrs).iter().find(|&attr| match attr.path.get_ident() {
191        Some(i) => i.to_string() == "associated",
192        None => false,
193    }) {
194        Some(attr) => attr,
195        None => {
196            return ParseError::new(ident.span(), "Missing `associated` attribute")
197                .to_compile_error()
198                .into()
199        }
200    };
201    let args = match associated.parse_args::<Args>() {
202        Ok(a) => a,
203        Err(e) => return e.to_compile_error().into(),
204    };
205
206    let variants = match data {
207        syn::Data::Struct(s) => {
208            return ParseError::new(
209                s.struct_token.span,
210                "Cannot derive `Associated` for structs",
211            )
212            .to_compile_error()
213            .into()
214        }
215        syn::Data::Union(u) => {
216            return ParseError::new(u.union_token.span, "Cannot derive `Associated` for unions")
217                .to_compile_error()
218                .into()
219        }
220        syn::Data::Enum(data) => data.variants,
221    };
222    let associated_variants = match parse_associated_values(&variants, &ident) {
223        Ok(v) => v,
224        Err(e) => return e,
225    };
226    let associated_type = args.assoc_type;
227
228    let match_block = generate_match_body(&ident, &associated_type, &associated_variants);
229    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
230    let impl_block = quote! {
231        impl #impl_generics associated::Associated for #ident #ty_generics #where_clause {
232            type AssociatedType = #associated_type;
233            fn get_associated(&self) -> &'static Self::AssociatedType {
234                match self {
235                    #match_block
236                }
237            }
238        }
239    };
240    impl_block.into()
241}