Skip to main content

macron_impl_from/
lib.rs

1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))]
2
3use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TokenStream2;
5use quote::quote;
6
7/// The implementation of [std::convert::From] trait
8#[proc_macro_derive(From, attributes(from))]
9pub fn impl_from(input: TokenStream) -> TokenStream {
10    let syn::DeriveInput {
11        ident, data, attrs, ..
12    } = syn::parse_macro_input!(input as syn::DeriveInput);
13
14    let struct_fields = match &data {
15        syn::Data::Struct(st) => Some(&st.fields),
16        _ => None,
17    };
18
19    let global_impls = read_attr_values(&attrs, struct_fields)
20        .into_iter()
21        .filter_map(|attr| match attr {
22            AttrValue::Custom { ty, expr } => Some(quote! {
23                impl ::std::convert::From<#ty> for #ident {
24                    fn from(value: #ty) -> Self {
25                        #expr
26                    }
27                }
28            }),
29            AttrValue::Skip => None,
30        });
31
32    match &data {
33        syn::Data::Struct(st) => {
34            let mut struct_field_impls = Vec::new();
35
36            let has_skip = read_attr_values(&attrs, struct_fields)
37                .iter()
38                .any(|a| matches!(a, AttrValue::Skip));
39
40            if st.fields.len() == 1 && !has_skip {
41                let field = st.fields.iter().next().unwrap();
42                let ty = &field.ty;
43                let body = match &field.ident {
44                    Some(id) => quote! { Self { #id: value } },
45                    None => quote! { Self(value) },
46                };
47
48                struct_field_impls.push(quote! {
49                    impl ::std::convert::From<#ty> for #ident {
50                        fn from(value: #ty) -> Self {
51                            #body
52                        }
53                    }
54                });
55            }
56
57            quote! {
58                #(#global_impls)*
59                #(#struct_field_impls)*
60            }
61            .into()
62        }
63
64        syn::Data::Enum(en) => {
65            let var_impls = en.variants.iter().map(
66                |syn::Variant {
67                     ident: var_ident,
68                     attrs,
69                     fields,
70                     ..
71                 }| {
72                    let parsed_attrs = read_attr_values(&attrs, Some(fields));
73
74                    if parsed_attrs.iter().any(|a| matches!(a, AttrValue::Skip)) {
75                        return quote! {};
76                    }
77
78                    let mut generated_for_variant = Vec::new();
79
80                    if fields.len() == 1 {
81                        let field = fields.iter().next().unwrap();
82                        let ty = &field.ty;
83                        let base_output = match fields {
84                            syn::Fields::Named(_) => {
85                                let field_ident = &field.ident;
86                                quote! { Self::#var_ident { #field_ident: value } }
87                            }
88                            syn::Fields::Unnamed(_) => quote! { Self::#var_ident(value) },
89                            syn::Fields::Unit => quote! { Self::#var_ident },
90                        };
91
92                        generated_for_variant.push(quote! {
93                            impl ::std::convert::From<#ty> for #ident {
94                                fn from(value: #ty) -> Self {
95                                    #base_output
96                                }
97                            }
98                        });
99                    }
100
101                    for attr in parsed_attrs {
102                        if let AttrValue::Custom { ty, expr } = attr {
103                            let custom_output = match &fields {
104                                syn::Fields::Named(_) => quote! { Self::#var_ident { #expr } },
105                                syn::Fields::Unnamed(_) => quote! { Self::#var_ident(#expr) },
106                                syn::Fields::Unit => quote! { Self::#var_ident },
107                            };
108                            generated_for_variant.push(quote! {
109                                impl ::std::convert::From<#ty> for #ident {
110                                    fn from(value: #ty) -> Self {
111                                        #custom_output
112                                    }
113                                }
114                            });
115                        }
116                    }
117
118                    quote! { #(#generated_for_variant)* }
119                },
120            );
121
122            quote! {
123                #(#global_impls)*
124                #(#var_impls)*
125            }
126            .into()
127        }
128
129        _ => panic!("Expected a 'struct' or 'enum'"),
130    }
131}
132
133enum AttrValue {
134    Custom { ty: syn::Type, expr: TokenStream2 },
135    Skip,
136}
137
138fn read_attr_values(attrs: &[syn::Attribute], _fields: Option<&syn::Fields>) -> Vec<AttrValue> {
139    attrs
140        .iter()
141        .filter(|attr| attr.path().is_ident("from"))
142        .map(|attr| match &attr.meta {
143            syn::Meta::List(list) => {
144                let args: FromArgs = list.parse_args().expect("Invalid arguments");
145                match args {
146                    FromArgs::With(ty, path) => AttrValue::Custom {
147                        ty: ty.clone(),
148                        expr: quote! { #path(value) },
149                    },
150                    FromArgs::Expr(ty, expr) => AttrValue::Custom {
151                        ty: ty.clone(),
152                        expr: quote! { #expr },
153                    },
154                    FromArgs::Skip => AttrValue::Skip,
155                }
156            }
157            _ => panic!("Unsupported attribute format. Use #[from(skip)] or #[from(Type, ...)]"),
158        })
159        .collect()
160}
161
162enum FromArgs {
163    With(syn::Type, syn::Path),
164    Expr(syn::Type, syn::Expr),
165    Skip,
166}
167
168impl syn::parse::Parse for FromArgs {
169    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
170        if input.peek(syn::Ident) {
171            let fork = input.fork();
172            let ident: syn::Ident = fork.parse()?;
173            if ident == "skip" {
174                input.parse::<syn::Ident>()?;
175                return Ok(FromArgs::Skip);
176            }
177        }
178
179        let ty: syn::Type = input.parse()?;
180        input.parse::<syn::token::Comma>()?;
181        let ident: syn::Ident = input.parse()?;
182        input.parse::<syn::token::Eq>()?;
183
184        if ident == "with" {
185            let path: syn::Path = input.parse()?;
186            Ok(FromArgs::With(ty, path))
187        } else if ident == "expr" {
188            let expr: syn::Expr = input.parse()?;
189            Ok(FromArgs::Expr(ty, expr))
190        } else {
191            Err(syn::Error::new(
192                ident.span(),
193                "expected `with`, `expr` or `skip`",
194            ))
195        }
196    }
197}