1use proc_macro::{self, TokenStream};
57use proc_macro2::TokenStream as TokenStream2;
58use quote::quote;
59use syn::{
60 parse::{Error as ParseError, Parse, ParseStream, Result as ParseResult},
61 parse_macro_input,
62 punctuated::Punctuated,
63 spanned::Spanned,
64 token::Comma,
65 Attribute, Binding, DeriveInput, Expr, Fields, Ident, Type, Variant,
66};
67
/// Arguments parsed from the enum-level `#[associated(Type = ...)]` attribute.
struct Args {
    /// The type written after `Type =`; it becomes `AssociatedType` in the generated impl.
    assoc_type: Type,
}
71
/// How a variant's associated value is produced in its generated match arm.
enum AssocKind {
    /// Selected by `#[assoc_const(...)]`: the expression is stored in a local `const`
    /// and a reference to that constant is returned from the arm.
    Constant,
    /// Selected by `#[assoc(...)]`: the expression is emitted unchanged as the arm's
    /// result, so it must itself evaluate to the `&'static` reference the generated
    /// `get_associated` returns.
    Static,
}
76
/// A variant attribute recognized as `assoc` or `assoc_const`, paired with its kind.
struct Assoc<'a> {
    /// Which attribute name matched (`assoc` -> `Static`, `assoc_const` -> `Constant`).
    kind: AssocKind,
    /// The matched attribute; its arguments are later parsed as the value expression.
    attr: &'a Attribute,
}
81
82impl Parse for Args {
83 fn parse(input: ParseStream) -> ParseResult<Self> {
84 let b = Binding::parse(input)?;
85 if b.ident.to_string() == "Type" {
86 return Ok(Args { assoc_type: b.ty });
87 }
88 Err(ParseError::new(b.ident.span(), "Expected `Type`"))
89 }
90}
91
92fn generate_match_body(
93 enum_ident: &Ident,
94 associated_type: &Type,
95 associated_variants: &Vec<(&Ident, &Fields, Expr, AssocKind)>,
96) -> TokenStream2 {
97 let mut match_block = TokenStream2::new();
98 match_block.extend(
99 associated_variants
100 .iter()
101 .map(|(variant_ident, fields, expr, kind)| {
102 let pattern = match fields {
103 syn::Fields::Named(_) => quote! {{..}},
104 syn::Fields::Unnamed(_) => quote! {(..)},
105 syn::Fields::Unit => quote! {},
106 };
107 match kind {
108 AssocKind::Constant => {
109 quote! {
110 #enum_ident::#variant_ident #pattern => {
111 const ASSOCIATED: #associated_type = #expr;
112 &ASSOCIATED
113 },
114 }
115 }
116 AssocKind::Static => {
117 quote! {
118 #enum_ident::#variant_ident #pattern => #expr,
119 }
120 }
121 }
122 }),
123 );
124 match_block
125}
126
127fn parse_associated_values<'a>(
132 variants: &'a Punctuated<Variant, Comma>,
133 enum_ident: &Ident,
134) -> Result<Vec<(&'a Ident, &'a Fields, Expr, AssocKind)>, TokenStream> {
135 let mut associated_values = Vec::new();
136 for v in variants.iter() {
137 if let Some(assoc) = v.attrs.iter().find_map(|attr| match attr.path.get_ident() {
138 Some(i) => {
139 let i = i.to_string();
140 if i == "assoc" {
141 Some(Assoc {
142 kind: AssocKind::Static,
143 attr,
144 })
145 } else if i == "assoc_const" {
146 Some(Assoc {
147 kind: AssocKind::Constant,
148 attr,
149 })
150 } else {
151 None
152 }
153 }
154 None => None,
155 }) {
156 let expr = match assoc.attr.parse_args::<Expr>() {
157 Ok(expr) => expr,
158 Err(e) => return Err(e.to_compile_error().into()),
159 };
160
161 associated_values.push((&v.ident, &v.fields, expr, assoc.kind));
162 } else {
163 return Err(ParseError::new(
164 v.span(),
165 format!(
166 "Cannot derive `Associated` for `{}`: Missing `assoc` or `assoc_const` attribute on variant `{}`",
167 enum_ident.to_string(),
168 v.ident.to_string()
169 )
170 )
171 .to_compile_error()
172 .into());
173 }
174 }
175 Ok(associated_values)
176}
177
178#[proc_macro_derive(Associated, attributes(associated, assoc, assoc_const))]
182pub fn associated_derive(input: TokenStream) -> TokenStream {
183 let DeriveInput {
184 attrs,
185 vis: _,
186 ident,
187 generics,
188 data,
189 } = parse_macro_input!(input);
190 let associated = match (&attrs).iter().find(|&attr| match attr.path.get_ident() {
191 Some(i) => i.to_string() == "associated",
192 None => false,
193 }) {
194 Some(attr) => attr,
195 None => {
196 return ParseError::new(ident.span(), "Missing `associated` attribute")
197 .to_compile_error()
198 .into()
199 }
200 };
201 let args = match associated.parse_args::<Args>() {
202 Ok(a) => a,
203 Err(e) => return e.to_compile_error().into(),
204 };
205
206 let variants = match data {
207 syn::Data::Struct(s) => {
208 return ParseError::new(
209 s.struct_token.span,
210 "Cannot derive `Associated` for structs",
211 )
212 .to_compile_error()
213 .into()
214 }
215 syn::Data::Union(u) => {
216 return ParseError::new(u.union_token.span, "Cannot derive `Associated` for unions")
217 .to_compile_error()
218 .into()
219 }
220 syn::Data::Enum(data) => data.variants,
221 };
222 let associated_variants = match parse_associated_values(&variants, &ident) {
223 Ok(v) => v,
224 Err(e) => return e,
225 };
226 let associated_type = args.assoc_type;
227
228 let match_block = generate_match_body(&ident, &associated_type, &associated_variants);
229 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
230 let impl_block = quote! {
231 impl #impl_generics associated::Associated for #ident #ty_generics #where_clause {
232 type AssociatedType = #associated_type;
233 fn get_associated(&self) -> &'static Self::AssociatedType {
234 match self {
235 #match_block
236 }
237 }
238 }
239 };
240 impl_block.into()
241}