1use proc_macro::{self, TokenStream};
2use quote::quote;
3use regex::Regex;
4use std::collections::HashSet;
5use syn::{
6 self, parse_macro_input, parse_quote, punctuated::Punctuated, spanned::Spanned, Attribute,
7 DataEnum, DataStruct, DeriveInput, ExprLit, Lit, LitStr, Meta, Path, Token,
8};
9
10#[proc_macro_derive(FromRegex, attributes(regex))]
11pub fn derive_from_regex(input: TokenStream) -> TokenStream {
12 let derive_input: DeriveInput = parse_macro_input!(input as DeriveInput);
13
14 impl_derive_from_regex(&derive_input).into()
15}
16
17fn impl_derive_from_regex(derive_input: &DeriveInput) -> proc_macro2::TokenStream {
18 match &derive_input.data {
19 syn::Data::Struct(data_struct) => {
20 impl_derive_from_regex_for_struct(derive_input, data_struct)
21 }
22 syn::Data::Enum(data_enum) => impl_derive_from_regex_for_enum(derive_input, data_enum),
23 syn::Data::Union(_) => syn::Error::new(
24 derive_input.ident.span(),
25 "FromRegex cannot be derived for unions",
26 )
27 .to_compile_error(),
28 }
29}
30
31struct FromRegexAttr {
33 pattern_literal: LitStr,
35}
36
37fn impl_derive_from_regex_for_struct(
38 derive_input: &DeriveInput,
39 data: &DataStruct,
40) -> proc_macro2::TokenStream {
41 let ident = &derive_input.ident;
42
43 let attr_args = match find_regex_attr(&derive_input.attrs) {
44 Some(attr) => match get_regex_attr(derive_input, attr) {
45 Ok(attr_args) => attr_args,
46 Err(err) => return err.into_compile_error(),
47 },
48
49 None => {
50 return syn::Error::new(derive_input.ident.span(), "missing regex attribute")
51 .into_compile_error()
52 }
53 };
54
55 let pattern_string = attr_args.pattern_literal.value();
57 let pattern = pattern_string.as_str();
58
59 let re = match Regex::new(pattern) {
60 Ok(re) => re,
61 Err(e) => {
62 return syn::Error::new_spanned(attr_args.pattern_literal, format!("{}", e))
63 .into_compile_error()
64 }
65 };
66
67 let return_type: Path = derive_input.ident.clone().into();
68
69 let impl_block: proc_macro2::TokenStream = match &data.fields {
70 syn::Fields::Named(fields_named) => {
71 impl_for_named_struct(fields_named, &re, pattern, return_type)
72 }
73 syn::Fields::Unnamed(fields_unnamed) => {
74 impl_for_tuple_struct(fields_unnamed, &re, pattern, return_type)
75 }
76 syn::Fields::Unit => impl_for_unit_struct(pattern, return_type),
77 };
78
79 let (impl_generics, ty_generics, where_clause) = derive_input.generics.split_for_impl();
80 quote! {
81 impl #impl_generics FromRegex for #ident #ty_generics #where_clause {
82 fn parse(input: &str) -> std::result::Result<#ident, std::string::String> {
83 #impl_block
84 Err(format!{"couldn't parse from \"{}\"", input}.to_string())
85 }
86 }
87 }
88}
89
90fn find_regex_attr(attrs: &[Attribute]) -> Option<&Attribute> {
92 attrs.iter().find(|attr| attr.path().is_ident("regex"))
93}
94
95fn get_regex_attr(
97 derive_input: &DeriveInput,
98 attr: &Attribute,
99) -> Result<FromRegexAttr, syn::Error> {
100 let mut pattern_literal: Option<LitStr> = None;
101
102 match attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_separated_nonempty) {
103 Ok(nested) => {
104 for meta in nested {
105 let meta_span = meta.span();
106 match meta {
107 Meta::NameValue(name_value) if name_value.path.is_ident("pattern") => {
109 match name_value.value {
110 syn::Expr::Lit(ExprLit {
111 lit: Lit::Str(lit_value),
112 ..
113 }) => pattern_literal = Some(lit_value),
114 _ => {
115 return Err(syn::Error::new(
117 meta_span,
118 "expcted `pattern = \"...\"` argument",
119 ));
120 }
121 }
122 }
123 _ => {
124 return Err(syn::Error::new_spanned(
125 meta,
126 "unsupported attribute argument",
127 ))
128 }
129 }
130 }
131 }
132 Err(err) => return Err(err),
133 }
134
135 let pattern_literal = match pattern_literal {
136 Some(p) => p,
137 None => {
138 return Err(syn::Error::new(
139 derive_input.ident.span(),
140 "expcted `pattern = \"...\"` argument",
141 ));
142 }
143 };
144
145 Ok(FromRegexAttr { pattern_literal })
146}
147
148fn impl_for_named_struct(
149 fields_named: &syn::FieldsNamed,
150 re: &Regex,
151 pattern: &str,
152 return_type: Path,
153) -> proc_macro2::TokenStream {
154 let expected_cap_groups: HashSet<String> = fields_named
155 .named
156 .iter()
157 .filter_map(|field| field.ident.clone().map(|name| name.to_string()))
158 .collect();
159 let actual_cap_groups: HashSet<String> = re
160 .capture_names()
161 .skip(1)
162 .filter_map(|name| name.map(|name| name.to_string()))
163 .collect();
164
165 let missing_groups: HashSet<String> = expected_cap_groups
167 .difference(&actual_cap_groups)
168 .cloned()
169 .collect();
170
171 let extra_groups: HashSet<String> = actual_cap_groups
173 .difference(&expected_cap_groups)
174 .cloned()
175 .collect();
176
177 let mut group_errors = Vec::new();
178
179 if !missing_groups.is_empty() {
180 group_errors.push(
181 syn::Error::new_spanned(
182 fields_named,
183 format!(
184 "missing capture groups for struct fields: {}",
185 missing_groups
186 .into_iter()
187 .collect::<Vec<String>>()
188 .join(", ")
189 ),
190 )
191 .into_compile_error(),
192 );
193 }
194 if !extra_groups.is_empty() {
195 group_errors.push(
196 syn::Error::new_spanned(
197 fields_named,
198 format!(
199 "these capture groups don't match any struct fields: {}",
200 extra_groups.into_iter().collect::<Vec<String>>().join(", ")
201 ),
202 )
203 .into_compile_error(),
204 );
205 }
206
207 if !group_errors.is_empty() {
208 return quote! {#(#group_errors)*};
209 }
210
211 let field_exprs = fields_named.named.iter().map(|field| {
212 let field_ident = field.ident.clone().expect("field of named struct");
213 let field_name = format!("{field_ident}");
214 let field_ty = &field.ty;
215
216 quote! {
217 #field_ident: caps[#field_name].parse::<#field_ty>().map_err(|err| err.to_string())?
218 }
219 });
220
221 quote! {
222 {
223 use once_cell::sync::Lazy;
224 static RE: Lazy<::regex::Regex> = Lazy::new(|| ::regex::Regex::new(#pattern).expect("Regex validated at compile time"));
225 if let Some(caps) = RE.captures(input) {
226 return Ok(#return_type{ #(#field_exprs),* })
227 }
228 }
229 }
230}
231
232fn impl_for_tuple_struct(
233 fields_unnamed: &syn::FieldsUnnamed,
234 re: &Regex,
235 pattern: &str,
236 return_type: Path,
237) -> proc_macro2::TokenStream {
238 let actual_groups = re.captures_len() - 1;
239 let expected_groups = fields_unnamed.unnamed.len();
240
241 if actual_groups > expected_groups {
242 return syn::Error::new_spanned(
243 fields_unnamed,
244 format!("too many capturing groups: expected {expected_groups}, got {actual_groups}"),
245 )
246 .into_compile_error();
247 } else if expected_groups > actual_groups {
248 return syn::Error::new_spanned(
249 fields_unnamed,
250 format!("missing capturing groups: expected {expected_groups}, got {actual_groups}"),
251 )
252 .into_compile_error();
253 }
254
255 let field_exprs = fields_unnamed.unnamed.iter().enumerate().map(|(i, field)| {
256 let index = i + 1;
257 let field_ty = &field.ty;
258 quote! {
259 caps[#index].parse::<#field_ty>().map_err(|err| err.to_string())?
260
261 }
262 });
263
264 quote! {
265 {
266 use once_cell::sync::Lazy;
267 static RE: Lazy<::regex::Regex> = Lazy::new(|| ::regex::Regex::new(#pattern).expect("Regex validated at compile time"));
268 if let Some(caps) = RE.captures(input) {
269 return Ok(#return_type( #(#field_exprs),* ))
270 }
271 }
272 }
273}
274
275fn impl_for_unit_struct(pattern: &str, return_type: Path) -> proc_macro2::TokenStream {
276 quote! {
277 {
278 use once_cell::sync::Lazy;
279 static RE: Lazy<::regex::Regex> = Lazy::new(|| ::regex::Regex::new(#pattern).expect("Regex validated at compile time"));
280 if RE.is_match(input) {
281 return Ok(#return_type);
282 }
283 }
284 }
285}
286
287fn impl_derive_from_regex_for_enum(
288 derive_input: &DeriveInput,
289 data: &DataEnum,
290) -> proc_macro2::TokenStream {
291 let enum_ident = &derive_input.ident;
292
293 let impls = data
294 .variants
295 .iter()
296 .map(|variant| -> proc_macro2::TokenStream {
297 let attr_args = match find_regex_attr(&variant.attrs) {
298 Some(attr) => match get_regex_attr(derive_input, attr) {
299 Ok(attr_args) => attr_args,
300 Err(err) => return err.into_compile_error(),
301 },
302
303 None => {
304 return syn::Error::new(variant.ident.span(), "missing regex attribute")
305 .into_compile_error()
306 }
307 };
308
309 let pattern_string = attr_args.pattern_literal.value();
311 let pattern = pattern_string.as_str();
312
313 let re = match Regex::new(pattern) {
314 Ok(re) => re,
315 Err(e) => {
316 return syn::Error::new_spanned(attr_args.pattern_literal, format!("{}", e))
317 .into_compile_error()
318 }
319 };
320
321 let variant_ident = &variant.ident;
322 let return_type = parse_quote!(#enum_ident::#variant_ident);
323
324 match &variant.fields {
325 syn::Fields::Named(fields_named) => {
326 impl_for_named_struct(fields_named, &re, pattern, return_type)
327 }
328 syn::Fields::Unnamed(fields_unnamed) => {
329 impl_for_tuple_struct(fields_unnamed, &re, pattern, return_type)
330 }
331 syn::Fields::Unit => impl_for_unit_struct(pattern, return_type),
332 }
333 });
334
335 let (impl_generics, ty_generics, where_clause) = derive_input.generics.split_for_impl();
336 quote! {
337 impl #impl_generics FromRegex for #enum_ident #ty_generics #where_clause {
338 fn parse(input: &str) -> std::result::Result<#enum_ident, std::string::String> {
339 #(#impls)*
340 Err(format!{"couldn't parse from \"{}\"", input}.to_string())
341 }
342 }
343 }
344}