1#![allow(clippy::let_and_return)]
2#![deny(
3 unused_variables,
4 mutable_borrow_reservation_conflict,
5 dead_code,
6 unused_must_use,
7 unused_imports
8)]
9
10extern crate proc_macro;
11
12use heck::ToSnakeCase;
13use proc_macro2::TokenStream;
14use proc_macro_error::*;
15use quote::*;
16use syn::{spanned::Spanned, *};
17
18#[proc_macro_derive(FromAttributes, attributes(bae))]
20#[proc_macro_error]
21pub fn from_attributes(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
22 let item = parse_macro_input!(input as ItemStruct);
23 FromAttributes::new(item).expand().into()
24}
25
/// Code generator for a single `#[derive(FromAttributes)]` invocation.
#[derive(Debug)]
struct FromAttributes {
    /// The struct the derive was applied to.
    item: ItemStruct,
    /// Generated impls, accumulated by the `expand_*` methods and returned by `expand`.
    tokens: TokenStream,
}
31
32impl FromAttributes {
33 fn new(item: ItemStruct) -> Self {
34 Self {
35 item,
36 tokens: TokenStream::new(),
37 }
38 }
39
40 fn expand(mut self) -> TokenStream {
41 self.expand_from_attributes_method();
42 self.expand_parse_impl();
43
44 if std::env::var("BAE_DEBUG").is_ok() {
45 eprintln!("{}", self.tokens);
46 }
47
48 self.tokens
49 }
50
51 fn struct_name(&self) -> &Ident {
52 &self.item.ident
53 }
54
55 fn attr_name(&self) -> LitStr {
56 let struct_name = self.struct_name();
57 let mut name = struct_name.to_string().to_snake_case();
58 for attr in &self.item.attrs {
59 if attr.path.is_ident("bae") {
60 if let Ok(lit) = attr.parse_args::<syn::LitStr>() {
61 name = lit.value();
62 }
63 }
64 }
65 LitStr::new(&name, struct_name.span())
66 }
67
68 fn expand_from_attributes_method(&mut self) {
69 let struct_name = self.struct_name();
70 let attr_name = self.attr_name().value();
71
72 let code = quote! {
73 impl ::better_bae::TryFromAttributes for #struct_name {
74 fn attr_name() -> &'static str {
75 #attr_name
76 }
77
78 fn try_from_attributes(attrs: &[::syn::Attribute]) -> ::syn::Result<Option<Self>> {
79 use ::syn::spanned::Spanned;
80
81 for attr in attrs {
82 match attr.path.get_ident() {
83 Some(ident) if ident == #attr_name => {
84 return Some(syn::parse2::<Self>(attr.tokens.clone())).transpose()
85 }
86 _ => {},
88 }
89 }
90
91 Ok(None)
92 }
93 }
94 };
95 self.tokens.extend(code);
96 }
97
98 fn expand_parse_impl(&mut self) {
99 let struct_name = self.struct_name();
100 let attr_name = self.attr_name();
101
102 let variable_declarations = self.item.fields.iter().map(|field| {
103 let name = &field.ident;
104 quote! { let mut #name = std::option::Option::None; }
105 });
106
107 let match_arms = self.item.fields.iter().map(|field| {
108 let field_name = get_field_name(field);
109 let pattern = LitStr::new(&field_name.to_string(), field.span());
110
111 if field_is_switch(field) {
112 quote! {
113 #pattern => {
114 #field_name = std::option::Option::Some(());
115 }
116 }
117 } else {
118 quote! {
119 #pattern => {
120 content.parse::<syn::Token![=]>()?;
121 #field_name = std::option::Option::Some(content.parse()?);
122 }
123 }
124 }
125 });
126
127 let unwrap_mandatory_fields = self
128 .item
129 .fields
130 .iter()
131 .filter(|field| !field_is_optional(field))
132 .map(|field| {
133 let field_name = get_field_name(field);
134 let arg_name = LitStr::new(&field_name.to_string(), field.span());
135
136 quote! {
137 let #field_name = if let std::option::Option::Some(#field_name) = #field_name {
138 #field_name
139 } else {
140 return syn::Result::Err(
141 input.error(
142 &format!("`#[{}]` is missing `{}` argument", #attr_name, #arg_name),
143 )
144 );
145 };
146 }
147 });
148
149 let set_fields = self.item.fields.iter().map(|field| {
150 let field_name = get_field_name(field);
151 quote! { #field_name, }
152 });
153
154 let code = quote! {
155 impl syn::parse::Parse for #struct_name {
156 #[allow(unreachable_code, unused_imports, unused_variables)]
157 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
158 #(#variable_declarations)*
159
160 let content;
161 syn::parenthesized!(content in input);
162
163 while !content.is_empty() {
164 let bae_attr_ident = content.parse::<syn::Ident>()?;
165
166 match &*bae_attr_ident.to_string() {
167 #(#match_arms)*
168 _ => {
169 content.parse::<proc_macro2::TokenStream>()?;
170 }
171 }
172
173 content.parse::<syn::Token![,]>().ok();
174 }
175
176 #(#unwrap_mandatory_fields)*
177
178 syn::Result::Ok(Self { #(#set_fields)* })
179 }
180 }
181 };
182 self.tokens.extend(code);
183 }
184}
185
186fn get_field_name(field: &Field) -> &Ident {
187 field
188 .ident
189 .as_ref()
190 .unwrap_or_else(|| abort!(field.span(), "Field without a name"))
191}
192
193fn field_is_optional(field: &Field) -> bool {
194 let type_path = if let Type::Path(type_path) = &field.ty {
195 type_path
196 } else {
197 return false;
198 };
199
200 let ident = &type_path
201 .path
202 .segments
203 .last()
204 .unwrap_or_else(|| abort!(field.span(), "Empty type path"))
205 .ident;
206
207 ident == "Option"
208}
209
210fn field_is_switch(field: &Field) -> bool {
211 let unit_type = syn::parse_str::<Type>("()").unwrap();
212 inner_type(&field.ty) == Some(&unit_type)
213}
214
215fn inner_type(ty: &Type) -> Option<&Type> {
216 let type_path = if let Type::Path(type_path) = ty {
217 type_path
218 } else {
219 return None;
220 };
221
222 let ty_args = &type_path
223 .path
224 .segments
225 .last()
226 .unwrap_or_else(|| abort!(ty.span(), "Empty type path"))
227 .arguments;
228
229 let ty_args = if let PathArguments::AngleBracketed(ty_args) = ty_args {
230 ty_args
231 } else {
232 return None;
233 };
234
235 let generic_arg = &ty_args
236 .args
237 .last()
238 .unwrap_or_else(|| abort!(ty_args.span(), "Empty generic argument"));
239
240 let ty = if let GenericArgument::Type(ty) = generic_arg {
241 ty
242 } else {
243 return None;
244 };
245
246 Some(ty)
247}
248
#[cfg(test)]
mod test {
    /// Runs the trybuild UI suite: every file matching
    /// `tests/compile_pass/*.rs` must compile, and every file matching
    /// `tests/compile_fail/*.rs` must fail to compile.
    ///
    /// Nothing from the parent module is referenced here, so no `use super::*`
    /// import (and no lint suppression for it) is needed.
    #[test]
    fn test_ui() {
        let t = trybuild::TestCases::new();
        t.pass("tests/compile_pass/*.rs");
        t.compile_fail("tests/compile_fail/*.rs");
    }
}