// better_bae_macros/lib.rs

#![allow(clippy::let_and_return)]
// `mutable_borrow_reservation_conflict` is intentionally not listed: rustc has
// removed that lint, and naming it only triggers a `renamed_and_removed_lints`
// warning (which fails `-D warnings` builds).
#![deny(
    unused_variables,
    dead_code,
    unused_must_use,
    unused_imports
)]
9
10extern crate proc_macro;
11
12use heck::ToSnakeCase;
13use proc_macro2::TokenStream;
14use proc_macro_error::*;
15use quote::*;
16use syn::{spanned::Spanned, *};
17
18/// See root module docs for more info.
19#[proc_macro_derive(FromAttributes, attributes(bae))]
20#[proc_macro_error]
21pub fn from_attributes(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
22    let item = parse_macro_input!(input as ItemStruct);
23    FromAttributes::new(item).expand().into()
24}
25
26#[derive(Debug)]
27struct FromAttributes {
28    item: ItemStruct,
29    tokens: TokenStream,
30}
31
32impl FromAttributes {
33    fn new(item: ItemStruct) -> Self {
34        Self {
35            item,
36            tokens: TokenStream::new(),
37        }
38    }
39
40    fn expand(mut self) -> TokenStream {
41        self.expand_from_attributes_method();
42        self.expand_parse_impl();
43
44        if std::env::var("BAE_DEBUG").is_ok() {
45            eprintln!("{}", self.tokens);
46        }
47
48        self.tokens
49    }
50
51    fn struct_name(&self) -> &Ident {
52        &self.item.ident
53    }
54
55    fn attr_name(&self) -> LitStr {
56        let struct_name = self.struct_name();
57        let mut name = struct_name.to_string().to_snake_case();
58        for attr in &self.item.attrs {
59            if attr.path.is_ident("bae") {
60                if let Ok(lit) = attr.parse_args::<syn::LitStr>() {
61                    name = lit.value();
62                }
63            }
64        }
65        LitStr::new(&name, struct_name.span())
66    }
67
68    fn expand_from_attributes_method(&mut self) {
69        let struct_name = self.struct_name();
70        let attr_name = self.attr_name().value();
71
72        let code = quote! {
73            impl ::better_bae::TryFromAttributes for #struct_name {
74                fn attr_name() -> &'static str {
75                    #attr_name
76                }
77
78                fn try_from_attributes(attrs: &[::syn::Attribute]) -> ::syn::Result<Option<Self>> {
79                    use ::syn::spanned::Spanned;
80
81                    for attr in attrs {
82                        match attr.path.get_ident() {
83                            Some(ident) if ident == #attr_name => {
84                                return Some(syn::parse2::<Self>(attr.tokens.clone())).transpose()
85                            }
86                            // Ignore other attributes
87                            _ => {},
88                        }
89                    }
90
91                    Ok(None)
92                }
93            }
94        };
95        self.tokens.extend(code);
96    }
97
98    fn expand_parse_impl(&mut self) {
99        let struct_name = self.struct_name();
100        let attr_name = self.attr_name();
101
102        let variable_declarations = self.item.fields.iter().map(|field| {
103            let name = &field.ident;
104            quote! { let mut #name = std::option::Option::None; }
105        });
106
107        let match_arms = self.item.fields.iter().map(|field| {
108            let field_name = get_field_name(field);
109            let pattern = LitStr::new(&field_name.to_string(), field.span());
110
111            if field_is_switch(field) {
112                quote! {
113                    #pattern => {
114                        #field_name = std::option::Option::Some(());
115                    }
116                }
117            } else {
118                quote! {
119                    #pattern => {
120                        content.parse::<syn::Token![=]>()?;
121                        #field_name = std::option::Option::Some(content.parse()?);
122                    }
123                }
124            }
125        });
126
127        let unwrap_mandatory_fields = self
128            .item
129            .fields
130            .iter()
131            .filter(|field| !field_is_optional(field))
132            .map(|field| {
133                let field_name = get_field_name(field);
134                let arg_name = LitStr::new(&field_name.to_string(), field.span());
135
136                quote! {
137                    let #field_name = if let std::option::Option::Some(#field_name) = #field_name {
138                        #field_name
139                    } else {
140                        return syn::Result::Err(
141                            input.error(
142                                &format!("`#[{}]` is missing `{}` argument", #attr_name, #arg_name),
143                            )
144                        );
145                    };
146                }
147            });
148
149        let set_fields = self.item.fields.iter().map(|field| {
150            let field_name = get_field_name(field);
151            quote! { #field_name, }
152        });
153
154        let code = quote! {
155            impl syn::parse::Parse for #struct_name {
156                #[allow(unreachable_code, unused_imports, unused_variables)]
157                fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
158                    #(#variable_declarations)*
159
160                    let content;
161                    syn::parenthesized!(content in input);
162
163                    while !content.is_empty() {
164                        let bae_attr_ident = content.parse::<syn::Ident>()?;
165
166                        match &*bae_attr_ident.to_string() {
167                            #(#match_arms)*
168                            _ => {
169                                content.parse::<proc_macro2::TokenStream>()?;
170                            }
171                        }
172
173                        content.parse::<syn::Token![,]>().ok();
174                    }
175
176                    #(#unwrap_mandatory_fields)*
177
178                    syn::Result::Ok(Self { #(#set_fields)* })
179                }
180            }
181        };
182        self.tokens.extend(code);
183    }
184}
185
186fn get_field_name(field: &Field) -> &Ident {
187    field
188        .ident
189        .as_ref()
190        .unwrap_or_else(|| abort!(field.span(), "Field without a name"))
191}
192
193fn field_is_optional(field: &Field) -> bool {
194    let type_path = if let Type::Path(type_path) = &field.ty {
195        type_path
196    } else {
197        return false;
198    };
199
200    let ident = &type_path
201        .path
202        .segments
203        .last()
204        .unwrap_or_else(|| abort!(field.span(), "Empty type path"))
205        .ident;
206
207    ident == "Option"
208}
209
210fn field_is_switch(field: &Field) -> bool {
211    let unit_type = syn::parse_str::<Type>("()").unwrap();
212    inner_type(&field.ty) == Some(&unit_type)
213}
214
215fn inner_type(ty: &Type) -> Option<&Type> {
216    let type_path = if let Type::Path(type_path) = ty {
217        type_path
218    } else {
219        return None;
220    };
221
222    let ty_args = &type_path
223        .path
224        .segments
225        .last()
226        .unwrap_or_else(|| abort!(ty.span(), "Empty type path"))
227        .arguments;
228
229    let ty_args = if let PathArguments::AngleBracketed(ty_args) = ty_args {
230        ty_args
231    } else {
232        return None;
233    };
234
235    let generic_arg = &ty_args
236        .args
237        .last()
238        .unwrap_or_else(|| abort!(ty_args.span(), "Empty generic argument"));
239
240    let ty = if let GenericArgument::Type(ty) = generic_arg {
241        ty
242    } else {
243        return None;
244    };
245
246    Some(ty)
247}
248
#[cfg(test)]
mod test {
    /// Runs the trybuild UI suite: every file in `tests/compile_pass` must
    /// compile, and every file in `tests/compile_fail` must fail to compile.
    #[test]
    fn test_ui() {
        let cases = trybuild::TestCases::new();
        cases.pass("tests/compile_pass/*.rs");
        cases.compile_fail("tests/compile_fail/*.rs");
    }
}