1use proc_macro::{self, TokenStream};
2use quote::quote;
3use regex::Regex;
4use std::collections::HashSet;
5use syn::{
6 self, parse_macro_input, parse_quote, punctuated::Punctuated, spanned::Spanned, Attribute,
7 DataEnum, DataStruct, DeriveInput, ExprLit, Lit, LitStr, Meta, Path, Token,
8};
9
10#[proc_macro_derive(FromRegex, attributes(regex))]
11pub fn derive_from_regex(input: TokenStream) -> TokenStream {
12 let derive_input: DeriveInput = parse_macro_input!(input as DeriveInput);
13
14 impl_derive_from_regex(&derive_input).into()
15}
16
17fn impl_derive_from_regex(derive_input: &DeriveInput) -> proc_macro2::TokenStream {
18 match &derive_input.data {
19 syn::Data::Struct(data_struct) => {
20 impl_derive_from_regex_for_struct(derive_input, data_struct)
21 }
22 syn::Data::Enum(data_enum) => impl_derive_from_regex_for_enum(derive_input, data_enum),
23 syn::Data::Union(_) => syn::Error::new(
24 derive_input.ident.span(),
25 "FromRegex cannot be derived for unions",
26 )
27 .to_compile_error(),
28 }
29}
30
31struct FromRegexAttr {
33 pattern_literal: LitStr,
35}
36
37fn impl_derive_from_regex_for_struct(
38 derive_input: &DeriveInput,
39 data: &DataStruct,
40) -> proc_macro2::TokenStream {
41 let ident = &derive_input.ident;
42
43 let attr_args = match find_regex_attr(&derive_input.attrs) {
44 Some(attr) => match get_regex_attr(derive_input, attr) {
45 Ok(attr_args) => attr_args,
46 Err(err) => return err.into_compile_error(),
47 },
48
49 None => {
50 return syn::Error::new(derive_input.ident.span(), "missing regex attribute")
51 .into_compile_error()
52 }
53 };
54
55 let pattern_string = attr_args.pattern_literal.value();
57 let pattern = pattern_string.as_str();
58
59 let re = match Regex::new(pattern) {
60 Ok(re) => re,
61 Err(e) => {
62 return syn::Error::new_spanned(attr_args.pattern_literal, format!("{}", e))
63 .into_compile_error()
64 }
65 };
66
67 let return_type: Path = derive_input.ident.clone().into();
68
69 let impl_block: proc_macro2::TokenStream = match &data.fields {
70 syn::Fields::Named(fields_named) => {
71 impl_for_named_struct(fields_named, &re, pattern, return_type)
72 }
73 syn::Fields::Unnamed(fields_unnamed) => {
74 impl_for_tuple_struct(fields_unnamed, &re, pattern, return_type)
75 }
76 syn::Fields::Unit => impl_for_unit_struct(pattern, return_type),
77 };
78
79 let (impl_generics, ty_generics, where_clause) = derive_input.generics.split_for_impl();
80 quote! {
81 impl #impl_generics FromRegex for #ident #ty_generics #where_clause {
82 fn parse(input: &str) -> std::result::Result<#ident, std::string::String> {
83 #impl_block
84 Err(format!{"couldn't parse from \"{}\"", input}.to_string())
85 }
86 }
87 }
88}
89
90fn find_regex_attr(attrs: &[Attribute]) -> Option<&Attribute> {
92 attrs.iter().find(|attr| attr.path().is_ident("regex"))
93}
94
95fn get_regex_attr(
97 derive_input: &DeriveInput,
98 attr: &Attribute,
99) -> Result<FromRegexAttr, syn::Error> {
100 let mut pattern_literal: Option<LitStr> = None;
101
102 match attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_separated_nonempty) {
103 Ok(nested) => {
104 for meta in nested {
105 let meta_span = meta.span();
106 match meta {
107 Meta::NameValue(name_value) if name_value.path.is_ident("pattern") => {
109 match name_value.value {
110 syn::Expr::Lit(ExprLit {
111 lit: Lit::Str(lit_value),
112 ..
113 }) => pattern_literal = Some(lit_value),
114 _ => {
115 return Err(syn::Error::new(
117 meta_span,
118 "expcted `pattern = \"...\"` argument",
119 ));
120 }
121 }
122 }
123 _ => {
124 return Err(syn::Error::new_spanned(
125 meta,
126 "unsupported attribute argument",
127 ))
128 }
129 }
130 }
131 }
132 Err(err) => return Err(err),
133 }
134
135 let pattern_literal = match pattern_literal {
136 Some(p) => p,
137 None => {
138 return Err(syn::Error::new(
139 derive_input.ident.span(),
140 "expcted `pattern = \"...\"` argument",
141 ));
142 }
143 };
144
145 Ok(FromRegexAttr { pattern_literal })
146}
147
148fn impl_for_named_struct(
149 fields_named: &syn::FieldsNamed,
150 re: &Regex,
151 pattern: &str,
152 return_type: Path,
153) -> proc_macro2::TokenStream {
154 let expected_cap_groups: HashSet<String> = fields_named
155 .named
156 .iter()
157 .filter_map(|field| field.ident.clone().map(|name| name.to_string()))
158 .collect();
159 let actual_cap_groups: HashSet<String> = re
160 .capture_names()
161 .skip(1)
162 .filter_map(|name| name.map(|name| name.to_string()))
163 .collect();
164
165 let missing_groups: HashSet<String> = expected_cap_groups
167 .difference(&actual_cap_groups)
168 .cloned()
169 .collect();
170
171 let extra_groups: HashSet<String> = actual_cap_groups
173 .difference(&expected_cap_groups)
174 .cloned()
175 .collect();
176
177 let mut group_errors = Vec::new();
178
179 if !missing_groups.is_empty() {
180 group_errors.push(
181 syn::Error::new_spanned(
182 fields_named,
183 format!(
184 "missing capture groups for struct fields: {}",
185 missing_groups
186 .into_iter()
187 .collect::<Vec<String>>()
188 .join(", ")
189 ),
190 )
191 .into_compile_error(),
192 );
193 }
194 if !extra_groups.is_empty() {
195 group_errors.push(
196 syn::Error::new_spanned(
197 fields_named,
198 format!(
199 "these capture groups don't match any struct fields: {}",
200 extra_groups.into_iter().collect::<Vec<String>>().join(", ")
201 ),
202 )
203 .into_compile_error(),
204 );
205 }
206
207 if !group_errors.is_empty() {
208 return quote! {#(#group_errors)*};
209 }
210
211 let field_exprs = fields_named.named.iter().map(|field| {
212 let field_ident = field.ident.clone().expect("field of named struct");
213 let field_name = format!("{field_ident}");
214 let field_ty = &field.ty;
215
216 quote! {
217 #field_ident: caps[#field_name].parse::<#field_ty>().map_err(|err| err.to_string())?
218 }
219 });
220
221 quote! {
222 let re = ::regex::Regex::new(#pattern).expect("Regex validated at compile time");
223 if let Some(caps) = re.captures(input) {
224 return Ok(#return_type{ #(#field_exprs),* })
225 }
226 }
227}
228
229fn impl_for_tuple_struct(
230 fields_unnamed: &syn::FieldsUnnamed,
231 re: &Regex,
232 pattern: &str,
233 return_type: Path,
234) -> proc_macro2::TokenStream {
235 let actual_groups = re.captures_len() - 1;
236 let expected_groups = fields_unnamed.unnamed.len();
237
238 if actual_groups > expected_groups {
239 return syn::Error::new_spanned(
240 fields_unnamed,
241 format!("too many capturing groups: expected {expected_groups}, got {actual_groups}"),
242 )
243 .into_compile_error();
244 } else if expected_groups > actual_groups {
245 return syn::Error::new_spanned(
246 fields_unnamed,
247 format!("missing capturing groups: expected {expected_groups}, got {actual_groups}"),
248 )
249 .into_compile_error();
250 }
251
252 let field_exprs = fields_unnamed.unnamed.iter().enumerate().map(|(i, field)| {
253 let index = i + 1;
254 let field_ty = &field.ty;
255 quote! {
256 caps[#index].parse::<#field_ty>().map_err(|err| err.to_string())?
257
258 }
259 });
260
261 quote! {
262 let re = ::regex::Regex::new(#pattern).expect("Regex validated at compile time");
263 if let Some(caps) = re.captures(input) {
264 return Ok(#return_type( #(#field_exprs),* ))
265 }
266 }
267}
268
269fn impl_for_unit_struct(pattern: &str, return_type: Path) -> proc_macro2::TokenStream {
270 quote! {
271 let re = ::regex::Regex::new(#pattern).expect("Regex validated at compile time");
272 if re.is_match(input) {
273 return Ok(#return_type);
274 }
275 }
276}
277
278fn impl_derive_from_regex_for_enum(
279 derive_input: &DeriveInput,
280 data: &DataEnum,
281) -> proc_macro2::TokenStream {
282 let enum_ident = &derive_input.ident;
283
284 let impls = data
285 .variants
286 .iter()
287 .map(|variant| -> proc_macro2::TokenStream {
288 let attr_args = match find_regex_attr(&variant.attrs) {
289 Some(attr) => match get_regex_attr(derive_input, attr) {
290 Ok(attr_args) => attr_args,
291 Err(err) => return err.into_compile_error(),
292 },
293
294 None => {
295 return syn::Error::new(variant.ident.span(), "missing regex attribute")
296 .into_compile_error()
297 }
298 };
299
300 let pattern_string = attr_args.pattern_literal.value();
302 let pattern = pattern_string.as_str();
303
304 let re = match Regex::new(pattern) {
305 Ok(re) => re,
306 Err(e) => {
307 return syn::Error::new_spanned(attr_args.pattern_literal, format!("{}", e))
308 .into_compile_error()
309 }
310 };
311
312 let variant_ident = &variant.ident;
313 let return_type = parse_quote!(#enum_ident::#variant_ident);
314
315 match &variant.fields {
316 syn::Fields::Named(fields_named) => {
317 impl_for_named_struct(fields_named, &re, pattern, return_type)
318 }
319 syn::Fields::Unnamed(fields_unnamed) => {
320 impl_for_tuple_struct(fields_unnamed, &re, pattern, return_type)
321 }
322 syn::Fields::Unit => impl_for_unit_struct(pattern, return_type),
323 }
324 });
325
326 let (impl_generics, ty_generics, where_clause) = derive_input.generics.split_for_impl();
327 quote! {
328 impl #impl_generics FromRegex for #enum_ident #ty_generics #where_clause {
329 fn parse(input: &str) -> std::result::Result<#enum_ident, std::string::String> {
330 #(#impls)*
331 Err(format!{"couldn't parse from \"{}\"", input}.to_string())
332 }
333 }
334 }
335}