derive_regex_proc_macro/
lib.rs

1use proc_macro::{self, TokenStream};
2use quote::quote;
3use regex::Regex;
4use std::collections::HashSet;
5use syn::{
6    self, parse_macro_input, parse_quote, punctuated::Punctuated, spanned::Spanned, Attribute,
7    DataEnum, DataStruct, DeriveInput, ExprLit, Lit, LitStr, Meta, Path, Token,
8};
9
10#[proc_macro_derive(FromRegex, attributes(regex))]
11pub fn derive_from_regex(input: TokenStream) -> TokenStream {
12    let derive_input: DeriveInput = parse_macro_input!(input as DeriveInput);
13
14    impl_derive_from_regex(&derive_input).into()
15}
16
17fn impl_derive_from_regex(derive_input: &DeriveInput) -> proc_macro2::TokenStream {
18    match &derive_input.data {
19        syn::Data::Struct(data_struct) => {
20            impl_derive_from_regex_for_struct(derive_input, data_struct)
21        }
22        syn::Data::Enum(data_enum) => impl_derive_from_regex_for_enum(derive_input, data_enum),
23        syn::Data::Union(_) => syn::Error::new(
24            derive_input.ident.span(),
25            "FromRegex cannot be derived for unions",
26        )
27        .to_compile_error(),
28    }
29}
30
31/// The configuration options for the #[regex(...)] attribute
32struct FromRegexAttr {
33    /// The pattern to match for the struct/variant
34    pattern_literal: LitStr,
35}
36
37fn impl_derive_from_regex_for_struct(
38    derive_input: &DeriveInput,
39    data: &DataStruct,
40) -> proc_macro2::TokenStream {
41    let ident = &derive_input.ident;
42
43    let attr_args = match find_regex_attr(&derive_input.attrs) {
44        Some(attr) => match get_regex_attr(derive_input, attr) {
45            Ok(attr_args) => attr_args,
46            Err(err) => return err.into_compile_error(),
47        },
48
49        None => {
50            return syn::Error::new(derive_input.ident.span(), "missing regex attribute")
51                .into_compile_error()
52        }
53    };
54
55    // needed to prevent the String from being dropped too soon
56    let pattern_string = attr_args.pattern_literal.value();
57    let pattern = pattern_string.as_str();
58
59    let re = match Regex::new(pattern) {
60        Ok(re) => re,
61        Err(e) => {
62            return syn::Error::new_spanned(attr_args.pattern_literal, format!("{}", e))
63                .into_compile_error()
64        }
65    };
66
67    let return_type: Path = derive_input.ident.clone().into();
68
69    let impl_block: proc_macro2::TokenStream = match &data.fields {
70        syn::Fields::Named(fields_named) => {
71            impl_for_named_struct(fields_named, &re, pattern, return_type)
72        }
73        syn::Fields::Unnamed(fields_unnamed) => {
74            impl_for_tuple_struct(fields_unnamed, &re, pattern, return_type)
75        }
76        syn::Fields::Unit => impl_for_unit_struct(pattern, return_type),
77    };
78
79    let (impl_generics, ty_generics, where_clause) = derive_input.generics.split_for_impl();
80    quote! {
81        impl #impl_generics FromRegex for #ident #ty_generics #where_clause {
82            fn parse(input: &str) -> std::result::Result<#ident, std::string::String> {
83                #impl_block
84                Err(format!{"couldn't parse from \"{}\"", input}.to_string())
85            }
86        }
87    }
88}
89
90/// Find the `#[regex(...)]` attribite in the item's attributes
91fn find_regex_attr(attrs: &[Attribute]) -> Option<&Attribute> {
92    attrs.iter().find(|attr| attr.path().is_ident("regex"))
93}
94
95/// Return the parameters of the `#[regex(...)]` attribute as a `FromRegexAttr `instance
96fn get_regex_attr(
97    derive_input: &DeriveInput,
98    attr: &Attribute,
99) -> Result<FromRegexAttr, syn::Error> {
100    let mut pattern_literal: Option<LitStr> = None;
101
102    match attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_separated_nonempty) {
103        Ok(nested) => {
104            for meta in nested {
105                let meta_span = meta.span();
106                match meta {
107                    // #[regex(pattern = "...")]
108                    Meta::NameValue(name_value) if name_value.path.is_ident("pattern") => {
109                        match name_value.value {
110                            syn::Expr::Lit(ExprLit {
111                                lit: Lit::Str(lit_value),
112                                ..
113                            }) => pattern_literal = Some(lit_value),
114                            _ => {
115                                // TODO: make span cover the whole meta item, not just the name
116                                return Err(syn::Error::new(
117                                    meta_span,
118                                    "expcted `pattern = \"...\"` argument",
119                                ));
120                            }
121                        }
122                    }
123                    _ => {
124                        return Err(syn::Error::new_spanned(
125                            meta,
126                            "unsupported attribute argument",
127                        ))
128                    }
129                }
130            }
131        }
132        Err(err) => return Err(err),
133    }
134
135    let pattern_literal = match pattern_literal {
136        Some(p) => p,
137        None => {
138            return Err(syn::Error::new(
139                derive_input.ident.span(),
140                "expcted `pattern = \"...\"` argument",
141            ));
142        }
143    };
144
145    Ok(FromRegexAttr { pattern_literal })
146}
147
148fn impl_for_named_struct(
149    fields_named: &syn::FieldsNamed,
150    re: &Regex,
151    pattern: &str,
152    return_type: Path,
153) -> proc_macro2::TokenStream {
154    let expected_cap_groups: HashSet<String> = fields_named
155        .named
156        .iter()
157        .filter_map(|field| field.ident.clone().map(|name| name.to_string()))
158        .collect();
159    let actual_cap_groups: HashSet<String> = re
160        .capture_names()
161        .skip(1)
162        .filter_map(|name| name.map(|name| name.to_string()))
163        .collect();
164
165    // struct fields not captured in a group
166    let missing_groups: HashSet<String> = expected_cap_groups
167        .difference(&actual_cap_groups)
168        .cloned()
169        .collect();
170
171    // capturing groups not matching any struct field
172    let extra_groups: HashSet<String> = actual_cap_groups
173        .difference(&expected_cap_groups)
174        .cloned()
175        .collect();
176
177    let mut group_errors = Vec::new();
178
179    if !missing_groups.is_empty() {
180        group_errors.push(
181            syn::Error::new_spanned(
182                fields_named,
183                format!(
184                    "missing capture groups for struct fields: {}",
185                    missing_groups
186                        .into_iter()
187                        .collect::<Vec<String>>()
188                        .join(", ")
189                ),
190            )
191            .into_compile_error(),
192        );
193    }
194    if !extra_groups.is_empty() {
195        group_errors.push(
196            syn::Error::new_spanned(
197                fields_named,
198                format!(
199                    "these capture groups don't match any struct fields: {}",
200                    extra_groups.into_iter().collect::<Vec<String>>().join(", ")
201                ),
202            )
203            .into_compile_error(),
204        );
205    }
206
207    if !group_errors.is_empty() {
208        return quote! {#(#group_errors)*};
209    }
210
211    let field_exprs = fields_named.named.iter().map(|field| {
212        let field_ident = field.ident.clone().expect("field of named struct");
213        let field_name = format!("{field_ident}");
214        let field_ty = &field.ty;
215
216        quote! {
217            #field_ident: caps[#field_name].parse::<#field_ty>().map_err(|err| err.to_string())?
218        }
219    });
220
221    quote! {
222        {
223            use once_cell::sync::Lazy;
224            static RE: Lazy<::regex::Regex> = Lazy::new(|| ::regex::Regex::new(#pattern).expect("Regex validated at compile time"));
225            if let Some(caps) = RE.captures(input) {
226                return Ok(#return_type{ #(#field_exprs),* })
227            }
228        }
229    }
230}
231
232fn impl_for_tuple_struct(
233    fields_unnamed: &syn::FieldsUnnamed,
234    re: &Regex,
235    pattern: &str,
236    return_type: Path,
237) -> proc_macro2::TokenStream {
238    let actual_groups = re.captures_len() - 1;
239    let expected_groups = fields_unnamed.unnamed.len();
240
241    if actual_groups > expected_groups {
242        return syn::Error::new_spanned(
243            fields_unnamed,
244            format!("too many capturing groups: expected {expected_groups}, got {actual_groups}"),
245        )
246        .into_compile_error();
247    } else if expected_groups > actual_groups {
248        return syn::Error::new_spanned(
249            fields_unnamed,
250            format!("missing capturing groups: expected {expected_groups}, got {actual_groups}"),
251        )
252        .into_compile_error();
253    }
254
255    let field_exprs = fields_unnamed.unnamed.iter().enumerate().map(|(i, field)| {
256        let index = i + 1;
257        let field_ty = &field.ty;
258        quote! {
259            caps[#index].parse::<#field_ty>().map_err(|err| err.to_string())?
260
261        }
262    });
263
264    quote! {
265        {
266            use once_cell::sync::Lazy;
267            static RE: Lazy<::regex::Regex> = Lazy::new(|| ::regex::Regex::new(#pattern).expect("Regex validated at compile time"));
268            if let Some(caps) = RE.captures(input) {
269                return Ok(#return_type( #(#field_exprs),* ))
270            }
271       }
272    }
273}
274
275fn impl_for_unit_struct(pattern: &str, return_type: Path) -> proc_macro2::TokenStream {
276    quote! {
277        {
278            use once_cell::sync::Lazy;
279            static RE: Lazy<::regex::Regex> = Lazy::new(|| ::regex::Regex::new(#pattern).expect("Regex validated at compile time"));
280            if RE.is_match(input) {
281                return Ok(#return_type);
282            }
283        }
284    }
285}
286
287fn impl_derive_from_regex_for_enum(
288    derive_input: &DeriveInput,
289    data: &DataEnum,
290) -> proc_macro2::TokenStream {
291    let enum_ident = &derive_input.ident;
292
293    let impls = data
294        .variants
295        .iter()
296        .map(|variant| -> proc_macro2::TokenStream {
297            let attr_args = match find_regex_attr(&variant.attrs) {
298                Some(attr) => match get_regex_attr(derive_input, attr) {
299                    Ok(attr_args) => attr_args,
300                    Err(err) => return err.into_compile_error(),
301                },
302
303                None => {
304                    return syn::Error::new(variant.ident.span(), "missing regex attribute")
305                        .into_compile_error()
306                }
307            };
308
309            // needed to prevent the String from being dropped too soon
310            let pattern_string = attr_args.pattern_literal.value();
311            let pattern = pattern_string.as_str();
312
313            let re = match Regex::new(pattern) {
314                Ok(re) => re,
315                Err(e) => {
316                    return syn::Error::new_spanned(attr_args.pattern_literal, format!("{}", e))
317                        .into_compile_error()
318                }
319            };
320
321            let variant_ident = &variant.ident;
322            let return_type = parse_quote!(#enum_ident::#variant_ident);
323
324            match &variant.fields {
325                syn::Fields::Named(fields_named) => {
326                    impl_for_named_struct(fields_named, &re, pattern, return_type)
327                }
328                syn::Fields::Unnamed(fields_unnamed) => {
329                    impl_for_tuple_struct(fields_unnamed, &re, pattern, return_type)
330                }
331                syn::Fields::Unit => impl_for_unit_struct(pattern, return_type),
332            }
333        });
334
335    let (impl_generics, ty_generics, where_clause) = derive_input.generics.split_for_impl();
336    quote! {
337        impl #impl_generics FromRegex for #enum_ident #ty_generics #where_clause {
338            fn parse(input: &str) -> std::result::Result<#enum_ident, std::string::String> {
339                #(#impls)*
340                Err(format!{"couldn't parse from \"{}\"", input}.to_string())
341            }
342        }
343    }
344}