ctreg_macro/
lib.rs

1/*!
2Implementation of the proc macro for `ctreg`. You should never use this crate
3directly.
4 */
5
6mod render;
7
8extern crate proc_macro;
9use proc_macro::TokenStream;
10
11use lazy_format::lazy_format;
12use proc_macro2::TokenStream as TokenStream2;
13use quote::{format_ident, quote};
14use regex_automata::meta::Regex;
15use regex_syntax::{
16    hir::{self, Capture, Hir, HirKind, Repetition},
17    parse as parse_regex,
18};
19use render::hir_expression;
20use syn::{
21    parse::{Parse, ParseStream},
22    parse_macro_input,
23    spanned::Spanned,
24    Ident, Token,
25};
26use thiserror::Error;
27
28use self::render::{CaptureType, HirType, InputType, RegexType};
29
30struct Request {
31    public: Option<Token![pub]>,
32    type_name: syn::Ident,
33    regex: syn::LitStr,
34}
35
36impl Parse for Request {
37    fn parse(input: ParseStream) -> syn::Result<Self> {
38        let public = input.parse()?;
39        let type_name = input.parse()?;
40        let _eq: Token![=] = input.parse()?;
41        let regex = input.parse()?;
42
43        Ok(Self {
44            public,
45            type_name,
46            regex,
47        })
48    }
49}
50
/// How often a sub-expression can match, relative to the regex as a whole.
///
/// Variants are declared from least to most permissive. The derived `Ord`
/// follows declaration order, so taking the maximum of two states yields the
/// more permissive one.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
enum HirRepState {
    /// Exactly-once context (the starting state, or a `{1}` repetition).
    Definite,
    /// At-most-once context, such as a `?` repetition or an alternation branch.
    Optional,
    /// Any other repetition bound (e.g. `*`, `+`, `{2}`).
    Repeating,
}
57
58impl HirRepState {
59    fn from_reps(repetition: &Repetition) -> Self {
60        match (repetition.min, repetition.max) {
61            (1, Some(1)) => Self::Definite,
62            (0, Some(1)) => Self::Optional,
63            _ => Self::Repeating,
64        }
65    }
66
67    fn and(self, other: HirRepState) -> Self {
68        Ord::max(self, other)
69    }
70
71    fn with(self, repetition: &Repetition) -> Self {
72        self.and(Self::from_reps(repetition))
73    }
74}
75
/// What the code generator needs to know about one named capture group.
#[derive(Clone, Copy, Debug)]
struct GroupInfo<'a> {
    /// The group's name, borrowed from the parsed regex. It becomes a field
    /// name on the generated captures struct.
    name: &'a str,
    /// Whether the group may be absent from a successful match, because it
    /// appears inside an alternation or an optional repetition.
    optional: bool,
    /// The group's capture index after renumbering; named groups start at 1.
    index: u32,
}
82
83fn get_group_index(groups: &[GroupInfo<'_>]) -> u32 {
84    groups.last().map(|group| group.index).unwrap_or(0) + 1
85}
86
87#[derive(Debug, Error)]
88enum HirError {
89    #[error("duplicate group name: {0:?}")]
90    DuplicateGroupName(String),
91
92    #[error("capture group {0:?} is repeating; capture groups can't repeat")]
93    RepeatingCaptureGroup(String),
94
95    #[error("capture group name {0:?} is not a valid rust identifier")]
96    BadName(String),
97}
98
/// Analyze and rewrite the syntax tree.
///
/// Returns an equivalent tree with two changes:
///
/// - Anonymous capture groups are erased. Their contents are kept, but they no
///   longer capture.
/// - Named capture groups are renumbered consecutively from 1, in the order
///   their opening parentheses appear. Each group is recorded in `groups`
///   before its contents are visited.
///
/// Every named group is validated and appended to `groups`, together with
/// whether it is optional. Names are borrowed from `hir`, hence the shared
/// `'a` lifetime.
///
/// `state` is the repetition context inherited from enclosing expressions.
/// Callers start it at `HirRepState::Definite`. It only ever becomes more
/// permissive (via `HirRepState::and`/`with`) as we descend into repetitions
/// and alternation branches.
///
/// # Errors
///
/// Checks run in this order for each named group:
/// - `HirError::BadName` if the name is not a valid Rust identifier,
/// - `HirError::DuplicateGroupName` if the name was already recorded,
/// - `HirError::RepeatingCaptureGroup` if the group is in a `Repeating` context.
fn process_hir_recurse<'a>(
    hir: &'a Hir,
    groups: &mut Vec<GroupInfo<'a>>,
    state: HirRepState,
) -> Result<Hir, HirError> {
    match *hir.kind() {
        // Literals and their equivalents are passed verbatim
        HirKind::Empty => Ok(Hir::empty()),
        HirKind::Literal(hir::Literal(ref lit)) => Ok(Hir::literal(lit.clone())),
        HirKind::Class(ref class) => Ok(Hir::class(class.clone())),
        HirKind::Look(look) => Ok(Hir::look(look)),

        // Need to compute the repetition state for repetitions
        HirKind::Repetition(ref repetition) => {
            let state = state.with(repetition);
            let sub = process_hir_recurse(&repetition.sub, groups, state)?;

            // Rebuild with the rewritten body; `min`, `max` and `greedy` are
            // copied unchanged from the original repetition.
            Ok(Hir::repetition(Repetition {
                sub: Box::new(sub),
                ..*repetition
            }))
        }

        // Capture groups are the most complicated. Need to remove anonymous
        // groups, renumber other groups, and check repetition / optional states.
        HirKind::Capture(ref capture) => {
            let Some(name) = capture.name.as_deref() else {
                // Anonymous groups don't capture in ctreg
                return process_hir_recurse(&capture.sub, groups, state);
            };

            // Let syn do the work for us of validating that this is a correct
            // rust identifier
            let _ident: Ident =
                syn::parse_str(name).map_err(|_| HirError::BadName(name.to_owned()))?;

            // Check duplicate groups
            if groups.iter().any(|group| group.name == name) {
                return Err(HirError::DuplicateGroupName(name.to_owned()));
            }

            // Check repeating groups
            if state == HirRepState::Repeating {
                return Err(HirError::RepeatingCaptureGroup(name.to_owned()));
            }

            let group_index = get_group_index(groups);

            // Record this group *before* recursing, so that any groups nested
            // inside it receive higher indices (numbering by opening paren).
            groups.push(GroupInfo {
                name,
                optional: matches!(state, HirRepState::Optional),
                index: group_index,
            });

            // A capture group doesn't change the repetition context itself;
            // its contents inherit the same `state`.
            let sub = process_hir_recurse(&capture.sub, groups, state)?;

            Ok(Hir::capture(Capture {
                index: group_index,
                name: Some(name.into()),
                sub: Box::new(sub),
            }))
        }

        // Concatenations are trivial
        HirKind::Concat(ref concat) => concat
            .iter()
            .map(|sub| process_hir_recurse(sub, groups, state))
            .collect::<Result<_, _>>()
            .map(Hir::concat),

        // regex syntax guarantees that alternations have at least 2 variants,
        // so each one is unconditionally optional. In the future we could
        // produce an enum, to reflect that at least one variant will exist
        HirKind::Alternation(ref alt) => alt
            .iter()
            .map(|sub| process_hir_recurse(sub, groups, state.and(HirRepState::Optional)))
            .collect::<Result<_, _>>()
            .map(Hir::alternation),
    }
}
183
184fn process_hir(hir: &Hir) -> Result<(Hir, Vec<GroupInfo<'_>>), HirError> {
185    let mut groups = Vec::new();
186
187    process_hir_recurse(hir, &mut groups, HirRepState::Definite).map(|hir| (hir, groups))
188}
189
190fn regex_impl_result(input: &Request) -> Result<TokenStream2, syn::Error> {
191    let hir = parse_regex(&input.regex.value()).map_err(|error| {
192        syn::Error::new(
193            input.regex.span(),
194            lazy_format!("error compiling regex:\n{error}"),
195        )
196    })?;
197
198    let (hir, groups) =
199        process_hir(&hir).map_err(|error| syn::Error::new(input.regex.span(), error))?;
200
201    // We don't actually use the compiled regex for anything, we just need to
202    // ensure that the `hir` does compile correctly.
203    let _compiled_regex = Regex::builder().build_from_hir(&hir).map_err(|error| {
204        syn::Error::new(
205            input.regex.span(),
206            lazy_format!("error compiling regex:\n{error}"),
207        )
208    })?;
209
210    let public = input.public;
211    let type_name = &input.type_name;
212
213    let slots_ident = Ident::new("slots", type_name.span());
214    let haystack_ident = Ident::new("haystack", type_name.span());
215
216    let mod_name = format_ident!("Mod{type_name}");
217    let matches_type_name = format_ident!("{type_name}Captures");
218
219    let matches_fields_definitions = groups.iter().map(|&GroupInfo { name, optional, .. }| {
220        let type_name = match optional {
221            false => quote! { #CaptureType<'a> },
222            true => quote! { ::core::option::Option<#CaptureType<'a>> },
223        };
224
225        let field_name = format_ident!("{name}", span = type_name.span());
226
227        quote! { #field_name : #type_name }
228    });
229
230    let matches_field_populators = groups.iter().map(
231        |&GroupInfo {
232             name,
233             optional,
234             index,
235         }| {
236            let slot_start = (index as usize) * 2;
237            let slot_end = slot_start + 1;
238
239            let field_name = format_ident!("{name}", span = type_name.span());
240
241            let populate = quote! {{
242                let slot_start = #slots_ident[#slot_start];
243                let slot_end = #slots_ident[#slot_end];
244
245                match slot_start {
246                    None => None,
247                    Some(start) => {
248                        let start = start.get();
249                        let end = unsafe { slot_end.unwrap_unchecked() }.get();
250                        let content = unsafe { #haystack_ident.get_unchecked(start..end) };
251
252                        Some(#CaptureType {start, end, content})
253                    }
254                }
255            }};
256
257            let expr = match optional {
258                true => populate,
259                false => quote! {
260                    match #populate {
261                        Some(capture) => capture,
262                        None => unsafe { ::core::hint::unreachable_unchecked() },
263                    }
264                },
265            };
266
267            quote! { #field_name : #expr }
268        },
269    );
270
271    let num_capture_groups = groups.len();
272
273    let captures_impl = (num_capture_groups > 0).then(|| quote! {
274        impl #type_name {
275            #[inline]
276            #[must_use]
277            pub fn captures<'i>(&self, #haystack_ident: &'i str) -> ::core::option::Option<#matches_type_name<'i>> {
278                let mut #slots_ident = [::core::option::Option::None; (#num_capture_groups + 1) * 2];
279                let _ = self.regex.search_slots(&#InputType::new(#haystack_ident), &mut #slots_ident)?;
280
281                ::core::option::Option::Some(#matches_type_name {
282                    #(#matches_field_populators ,)*
283                })
284            }
285        }
286
287        #[derive(Debug, Clone, Copy)]
288        pub struct #matches_type_name<'a> {
289            #(pub #matches_fields_definitions,)*
290        }
291    });
292
293    let captures_export = captures_impl.is_some().then(|| {
294        quote! {
295            #public use #mod_name::#matches_type_name
296        }
297    });
298
299    let rendered_hir = hir_expression(&hir);
300
301    Ok(quote! {
302        // The implementations are put into a submodule to ensure that the
303        // caller of the regex macro doesn't have access to the internals
304        // of these types
305        #[doc(hidden)]
306        #[allow(non_snake_case)]
307        mod #mod_name {
308            #[derive(Debug, Clone)]
309            pub struct #type_name {
310                regex: #RegexType,
311            }
312
313            impl #type_name {
314                #[inline]
315                #[must_use]
316                pub fn new() -> Self {
317                    let hir: #HirType = #rendered_hir;
318                    let regex = #RegexType::builder()
319                        .build_from_hir(&hir)
320                        .expect("regex compilation failed, despite compile-time verification");
321                    Self { regex }
322                }
323
324                #[inline]
325                #[must_use]
326                pub fn is_match(&self, haystack: &str) -> bool {
327                    self.regex.is_match(haystack)
328                }
329
330                #[inline]
331                #[must_use]
332                pub fn find<'i>(&self, haystack: &'i str) -> ::core::option::Option<#CaptureType<'i>> {
333                    let capture = self.regex.find(haystack)?;
334                    let span = capture.span();
335
336                    let start = span.start;
337                    let end = span.end;
338                    let content = unsafe { haystack.get_unchecked(start..end) };
339
340                    Some(#CaptureType { start, end, content })
341                }
342            }
343
344            impl ::core::default::Default for #type_name {
345                fn default() -> Self {
346                    Self::new()
347                }
348            }
349
350            #captures_impl
351        }
352
353        #public use #mod_name::#type_name;
354        #captures_export;
355
356    })
357}
358
359#[proc_macro]
360pub fn regex_impl(input: TokenStream) -> TokenStream {
361    let input = parse_macro_input!(input as Request);
362
363    regex_impl_result(&input)
364        .unwrap_or_else(|error| error.into_compile_error())
365        .into()
366}