1mod render;
7
8extern crate proc_macro;
9use proc_macro::TokenStream;
10
11use lazy_format::lazy_format;
12use proc_macro2::TokenStream as TokenStream2;
13use quote::{format_ident, quote};
14use regex_automata::meta::Regex;
15use regex_syntax::{
16 hir::{self, Capture, Hir, HirKind, Repetition},
17 parse as parse_regex,
18};
19use render::hir_expression;
20use syn::{
21 parse::{Parse, ParseStream},
22 parse_macro_input,
23 spanned::Spanned,
24 Ident, Token,
25};
26use thiserror::Error;
27
28use self::render::{CaptureType, HirType, InputType, RegexType};
29
30struct Request {
31 public: Option<Token![pub]>,
32 type_name: syn::Ident,
33 regex: syn::LitStr,
34}
35
36impl Parse for Request {
37 fn parse(input: ParseStream) -> syn::Result<Self> {
38 let public = input.parse()?;
39 let type_name = input.parse()?;
40 let _eq: Token![=] = input.parse()?;
41 let regex = input.parse()?;
42
43 Ok(Self {
44 public,
45 type_name,
46 regex,
47 })
48 }
49}
50
/// How many times a sub-expression can match, given its enclosing context.
///
/// Variants are declared from most to least restrictive; the derived `Ord`
/// relies on this order so that combining two states can simply take the max.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
enum HirRepState {
    /// Matches exactly once whenever the whole regex matches.
    Definite,
    /// Matches at most once (under `?` or an alternation branch).
    Optional,
    /// May match more than once.
    Repeating,
}
57
58impl HirRepState {
59 fn from_reps(repetition: &Repetition) -> Self {
60 match (repetition.min, repetition.max) {
61 (1, Some(1)) => Self::Definite,
62 (0, Some(1)) => Self::Optional,
63 _ => Self::Repeating,
64 }
65 }
66
67 fn and(self, other: HirRepState) -> Self {
68 Ord::max(self, other)
69 }
70
71 fn with(self, repetition: &Repetition) -> Self {
72 self.and(Self::from_reps(repetition))
73 }
74}
75
/// A named capture group discovered while processing the regex HIR.
#[derive(Clone, Copy, Debug)]
struct GroupInfo<'a> {
    /// The group's name, borrowed from the parsed HIR.
    name: &'a str,
    /// `true` when the group may not participate in a successful match.
    optional: bool,
    /// The renumbered capture index assigned to this group.
    index: u32,
}
82
83fn get_group_index(groups: &[GroupInfo<'_>]) -> u32 {
84 groups.last().map(|group| group.index).unwrap_or(0) + 1
85}
86
87#[derive(Debug, Error)]
88enum HirError {
89 #[error("duplicate group name: {0:?}")]
90 DuplicateGroupName(String),
91
92 #[error("capture group {0:?} is repeating; capture groups can't repeat")]
93 RepeatingCaptureGroup(String),
94
95 #[error("capture group name {0:?} is not a valid rust identifier")]
96 BadName(String),
97}
98
99fn process_hir_recurse<'a>(
104 hir: &'a Hir,
105 groups: &mut Vec<GroupInfo<'a>>,
106 state: HirRepState,
107) -> Result<Hir, HirError> {
108 match *hir.kind() {
109 HirKind::Empty => Ok(Hir::empty()),
111 HirKind::Literal(hir::Literal(ref lit)) => Ok(Hir::literal(lit.clone())),
112 HirKind::Class(ref class) => Ok(Hir::class(class.clone())),
113 HirKind::Look(look) => Ok(Hir::look(look)),
114
115 HirKind::Repetition(ref repetition) => {
117 let state = state.with(repetition);
118 let sub = process_hir_recurse(&repetition.sub, groups, state)?;
119
120 Ok(Hir::repetition(Repetition {
121 sub: Box::new(sub),
122 ..*repetition
123 }))
124 }
125
126 HirKind::Capture(ref capture) => {
129 let Some(name) = capture.name.as_deref() else {
130 return process_hir_recurse(&capture.sub, groups, state);
132 };
133
134 let _ident: Ident =
137 syn::parse_str(name).map_err(|_| HirError::BadName(name.to_owned()))?;
138
139 if groups.iter().any(|group| group.name == name) {
141 return Err(HirError::DuplicateGroupName(name.to_owned()));
142 }
143
144 if state == HirRepState::Repeating {
146 return Err(HirError::RepeatingCaptureGroup(name.to_owned()));
147 }
148
149 let group_index = get_group_index(groups);
150
151 groups.push(GroupInfo {
152 name,
153 optional: matches!(state, HirRepState::Optional),
154 index: group_index,
155 });
156
157 let sub = process_hir_recurse(&capture.sub, groups, state)?;
158
159 Ok(Hir::capture(Capture {
160 index: group_index,
161 name: Some(name.into()),
162 sub: Box::new(sub),
163 }))
164 }
165
166 HirKind::Concat(ref concat) => concat
168 .iter()
169 .map(|sub| process_hir_recurse(sub, groups, state))
170 .collect::<Result<_, _>>()
171 .map(Hir::concat),
172
173 HirKind::Alternation(ref alt) => alt
177 .iter()
178 .map(|sub| process_hir_recurse(sub, groups, state.and(HirRepState::Optional)))
179 .collect::<Result<_, _>>()
180 .map(Hir::alternation),
181 }
182}
183
184fn process_hir(hir: &Hir) -> Result<(Hir, Vec<GroupInfo<'_>>), HirError> {
185 let mut groups = Vec::new();
186
187 process_hir_recurse(hir, &mut groups, HirRepState::Definite).map(|hir| (hir, groups))
188}
189
190fn regex_impl_result(input: &Request) -> Result<TokenStream2, syn::Error> {
191 let hir = parse_regex(&input.regex.value()).map_err(|error| {
192 syn::Error::new(
193 input.regex.span(),
194 lazy_format!("error compiling regex:\n{error}"),
195 )
196 })?;
197
198 let (hir, groups) =
199 process_hir(&hir).map_err(|error| syn::Error::new(input.regex.span(), error))?;
200
201 let _compiled_regex = Regex::builder().build_from_hir(&hir).map_err(|error| {
204 syn::Error::new(
205 input.regex.span(),
206 lazy_format!("error compiling regex:\n{error}"),
207 )
208 })?;
209
210 let public = input.public;
211 let type_name = &input.type_name;
212
213 let slots_ident = Ident::new("slots", type_name.span());
214 let haystack_ident = Ident::new("haystack", type_name.span());
215
216 let mod_name = format_ident!("Mod{type_name}");
217 let matches_type_name = format_ident!("{type_name}Captures");
218
219 let matches_fields_definitions = groups.iter().map(|&GroupInfo { name, optional, .. }| {
220 let type_name = match optional {
221 false => quote! { #CaptureType<'a> },
222 true => quote! { ::core::option::Option<#CaptureType<'a>> },
223 };
224
225 let field_name = format_ident!("{name}", span = type_name.span());
226
227 quote! { #field_name : #type_name }
228 });
229
230 let matches_field_populators = groups.iter().map(
231 |&GroupInfo {
232 name,
233 optional,
234 index,
235 }| {
236 let slot_start = (index as usize) * 2;
237 let slot_end = slot_start + 1;
238
239 let field_name = format_ident!("{name}", span = type_name.span());
240
241 let populate = quote! {{
242 let slot_start = #slots_ident[#slot_start];
243 let slot_end = #slots_ident[#slot_end];
244
245 match slot_start {
246 None => None,
247 Some(start) => {
248 let start = start.get();
249 let end = unsafe { slot_end.unwrap_unchecked() }.get();
250 let content = unsafe { #haystack_ident.get_unchecked(start..end) };
251
252 Some(#CaptureType {start, end, content})
253 }
254 }
255 }};
256
257 let expr = match optional {
258 true => populate,
259 false => quote! {
260 match #populate {
261 Some(capture) => capture,
262 None => unsafe { ::core::hint::unreachable_unchecked() },
263 }
264 },
265 };
266
267 quote! { #field_name : #expr }
268 },
269 );
270
271 let num_capture_groups = groups.len();
272
273 let captures_impl = (num_capture_groups > 0).then(|| quote! {
274 impl #type_name {
275 #[inline]
276 #[must_use]
277 pub fn captures<'i>(&self, #haystack_ident: &'i str) -> ::core::option::Option<#matches_type_name<'i>> {
278 let mut #slots_ident = [::core::option::Option::None; (#num_capture_groups + 1) * 2];
279 let _ = self.regex.search_slots(&#InputType::new(#haystack_ident), &mut #slots_ident)?;
280
281 ::core::option::Option::Some(#matches_type_name {
282 #(#matches_field_populators ,)*
283 })
284 }
285 }
286
287 #[derive(Debug, Clone, Copy)]
288 pub struct #matches_type_name<'a> {
289 #(pub #matches_fields_definitions,)*
290 }
291 });
292
293 let captures_export = captures_impl.is_some().then(|| {
294 quote! {
295 #public use #mod_name::#matches_type_name
296 }
297 });
298
299 let rendered_hir = hir_expression(&hir);
300
301 Ok(quote! {
302 #[doc(hidden)]
306 #[allow(non_snake_case)]
307 mod #mod_name {
308 #[derive(Debug, Clone)]
309 pub struct #type_name {
310 regex: #RegexType,
311 }
312
313 impl #type_name {
314 #[inline]
315 #[must_use]
316 pub fn new() -> Self {
317 let hir: #HirType = #rendered_hir;
318 let regex = #RegexType::builder()
319 .build_from_hir(&hir)
320 .expect("regex compilation failed, despite compile-time verification");
321 Self { regex }
322 }
323
324 #[inline]
325 #[must_use]
326 pub fn is_match(&self, haystack: &str) -> bool {
327 self.regex.is_match(haystack)
328 }
329
330 #[inline]
331 #[must_use]
332 pub fn find<'i>(&self, haystack: &'i str) -> ::core::option::Option<#CaptureType<'i>> {
333 let capture = self.regex.find(haystack)?;
334 let span = capture.span();
335
336 let start = span.start;
337 let end = span.end;
338 let content = unsafe { haystack.get_unchecked(start..end) };
339
340 Some(#CaptureType { start, end, content })
341 }
342 }
343
344 impl ::core::default::Default for #type_name {
345 fn default() -> Self {
346 Self::new()
347 }
348 }
349
350 #captures_impl
351 }
352
353 #public use #mod_name::#type_name;
354 #captures_export;
355
356 })
357}
358
359#[proc_macro]
360pub fn regex_impl(input: TokenStream) -> TokenStream {
361 let input = parse_macro_input!(input as Request);
362
363 regex_impl_result(&input)
364 .unwrap_or_else(|error| error.into_compile_error())
365 .into()
366}