litcheck_filecheck/pattern/matcher/matchers/
regex_set.rs

1use regex_automata::{
2    dfa::{self, dense, onepass, Automaton, OverlappingState, StartKind},
3    meta,
4    nfa::thompson,
5    util::{
6        captures::{Captures, GroupInfo},
7        syntax,
8    },
9    Anchored, MatchKind, PatternID,
10};
11
12use crate::{
13    ast::{Capture, RegexPattern},
14    common::*,
15    pattern::matcher::regex,
16};
17
18#[derive(Default, Clone)]
19#[allow(clippy::large_enum_variant)]
20pub enum CapturingRegex {
21    #[default]
22    None,
23    Onepass {
24        re: onepass::DFA,
25        cache: onepass::Cache,
26    },
27    Default {
28        re: meta::Regex,
29        cache: meta::Cache,
30    },
31}
32
33struct Search<'input> {
34    crlf: bool,
35    forward: bool,
36    input: regex_automata::Input<'input>,
37    reverse_input: regex_automata::Input<'input>,
38    last_match_end: Option<usize>,
39    captures: Captures,
40    overlapping: OverlappingState,
41    overlapping_reverse: OverlappingState,
42}
43impl<'input> Search<'input> {
44    fn start(input: Input<'input>, captures: Captures) -> Self {
45        let crlf = input.is_crlf();
46        let input: regex_automata::Input<'input> = input.into();
47        let reverse_input = input.clone().earliest(false);
48        let overlapping = OverlappingState::start();
49        let overlapping_reverse = OverlappingState::start();
50        Self {
51            crlf,
52            forward: true,
53            input,
54            reverse_input,
55            last_match_end: None,
56            overlapping,
57            overlapping_reverse,
58            captures,
59        }
60    }
61}
62
63pub struct RegexSetSearcher<'a, 'input, A = dense::DFA<Vec<u32>>> {
64    search: Search<'input>,
65    /// The set of raw input patterns from which
66    /// this matcher was constructed
67    patterns: Vec<Span<Cow<'a, str>>>,
68    /// The compiled regex which will be used to search the input buffer
69    regex: dfa::regex::Regex<A>,
70    /// The regex used to obtain capturing groups, if there are any
71    capturing_regex: CapturingRegex,
72    /// Metadata about captures in the given patterns
73    ///
74    /// Each pattern gets its own vector of capture info, since
75    /// there is no requirement that all patterns have the same
76    /// number or type of captures
77    capture_types: Vec<Vec<Capture>>,
78}
79impl<'a, 'input> RegexSetSearcher<'a, 'input> {
80    pub fn new(
81        input: Input<'input>,
82        patterns: Vec<RegexPattern<'a>>,
83        config: &Config,
84        interner: &StringInterner,
85    ) -> DiagResult<Self> {
86        let start_kind = if input.is_anchored() {
87            StartKind::Anchored
88        } else {
89            StartKind::Unanchored
90        };
91        Ok(
92            RegexSetMatcher::new_with_start_kind(start_kind, patterns, config, interner)?
93                .into_searcher(input),
94        )
95    }
96}
97impl<'a, 'input, A: Automaton> RegexSetSearcher<'a, 'input, A> {
98    pub fn from_matcher(matcher: RegexSetMatcher<'a, A>, input: Input<'input>) -> Self {
99        let captures = matcher.captures;
100        let search = Search::start(input, captures);
101        Self {
102            search,
103            patterns: matcher.patterns,
104            regex: matcher.regex,
105            capturing_regex: matcher.capturing_regex,
106            capture_types: matcher.capture_types,
107        }
108    }
109}
110
111impl<'a, 'input, A: Automaton + Clone> RegexSetSearcher<'a, 'input, A> {
112    pub fn from_matcher_ref(matcher: &RegexSetMatcher<'a, A>, input: Input<'input>) -> Self {
113        let captures = matcher.captures.clone();
114        let search = Search::start(input, captures);
115        Self {
116            search,
117            patterns: matcher.patterns.clone(),
118            regex: matcher.regex.clone(),
119            capturing_regex: matcher.capturing_regex.clone(),
120            capture_types: matcher.capture_types.clone(),
121        }
122    }
123}
124impl<'a, 'input, A> fmt::Debug for RegexSetSearcher<'a, 'input, A> {
125    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
126        f.debug_struct("RegexSetSearcher")
127            .field("patterns", &self.patterns)
128            .field("capture_types", &self.search.captures)
129            .field("crlf", &self.search.crlf)
130            .field("forward", &self.search.forward)
131            .field("last_match_end", &self.search.last_match_end)
132            .finish()
133    }
134}
135impl<'a, 'input, A> Spanned for RegexSetSearcher<'a, 'input, A> {
136    fn span(&self) -> SourceSpan {
137        let start = self.patterns.iter().map(|p| p.start()).min().unwrap();
138        let end = self.patterns.iter().map(|p| p.end()).max().unwrap();
139        SourceSpan::from(start..end)
140    }
141}
142
143pub struct RegexSetMatcher<'a, A = dense::DFA<Vec<u32>>> {
144    /// The set of raw input patterns from which
145    /// this matcher was constructed
146    patterns: Vec<Span<Cow<'a, str>>>,
147    /// The compiled regex which will be used to search the input buffer
148    regex: dfa::regex::Regex<A>,
149    /// The regex used to obtain capturing groups, if there are any
150    capturing_regex: CapturingRegex,
151    /// Captures storage
152    captures: Captures,
153    /// Metadata about captures in the given patterns
154    ///
155    /// Each pattern gets its own vector of capture info, since
156    /// there is no requirement that all patterns have the same
157    /// number or type of captures
158    capture_types: Vec<Vec<Capture>>,
159}
160impl<'a> RegexSetMatcher<'a> {
161    pub fn new(
162        patterns: Vec<RegexPattern<'a>>,
163        config: &Config,
164        interner: &StringInterner,
165    ) -> DiagResult<Self> {
166        Self::new_with_start_kind(StartKind::Both, patterns, config, interner)
167    }
168
169    pub fn new_with_start_kind(
170        start_kind: StartKind,
171        patterns: Vec<RegexPattern<'a>>,
172        config: &Config,
173        interner: &StringInterner,
174    ) -> DiagResult<Self> {
175        let regex = dfa::regex::Regex::builder()
176            .dense(
177                dense::Config::new()
178                    .match_kind(MatchKind::All)
179                    .start_kind(start_kind)
180                    .starts_for_each_pattern(true),
181            )
182            .syntax(
183                syntax::Config::new()
184                    .multi_line(true)
185                    .case_insensitive(config.ignore_case),
186            )
187            .build_many(&patterns)
188            .map_err(|error| {
189                regex::build_error_to_diagnostic(error, patterns.len(), |id| patterns[id].span())
190            })?;
191
192        let has_captures = patterns.iter().any(|p| !p.captures.is_empty());
193        let (capturing_regex, captures) = if !has_captures {
194            (CapturingRegex::None, Captures::empty(GroupInfo::empty()))
195        } else {
196            onepass::DFA::builder()
197                .syntax(
198                    syntax::Config::new()
199                        .utf8(false)
200                        .multi_line(true)
201                        .case_insensitive(config.ignore_case),
202                )
203                .thompson(thompson::Config::new().utf8(false))
204                .configure(onepass::Config::new().starts_for_each_pattern(true))
205                .build_many(&patterns)
206                .map_or_else(
207                    |_| {
208                        let re = Regex::builder()
209                            .configure(Regex::config().match_kind(MatchKind::All))
210                            .syntax(
211                                syntax::Config::new()
212                                    .multi_line(true)
213                                    .case_insensitive(config.ignore_case),
214                            )
215                            .build_many(&patterns)
216                            .unwrap();
217                        let cache = re.create_cache();
218                        let captures = re.create_captures();
219                        (CapturingRegex::Default { re, cache }, captures)
220                    },
221                    |re| {
222                        let cache = re.create_cache();
223                        let captures = re.create_captures();
224                        (CapturingRegex::Onepass { re, cache }, captures)
225                    },
226                )
227        };
228
229        // Compute capture group information
230        let mut capture_types = vec![vec![]; patterns.len()];
231        let mut strings = Vec::with_capacity(patterns.len());
232        let groups = captures.group_info();
233        for (
234            i,
235            RegexPattern {
236                pattern,
237                captures: pattern_captures,
238            },
239        ) in patterns.into_iter().enumerate()
240        {
241            let span = pattern.span();
242            strings.push(pattern);
243            let pid = PatternID::new_unchecked(i);
244            let num_captures = groups.group_len(pid);
245            capture_types[i].resize(num_captures, Capture::Ignore(span));
246            for capture in pattern_captures.into_iter() {
247                if let Capture::Ignore(_) = capture {
248                    continue;
249                }
250                if let Some(name) = capture.group_name() {
251                    let group_name = interner.resolve(name);
252                    let group_id = groups.to_index(pid, group_name).unwrap_or_else(|| {
253                        panic!("expected group for capture of '{group_name}' in pattern {i}")
254                    });
255                    capture_types[i][group_id] = capture;
256                } else {
257                    assert_eq!(
258                        &capture_types[i][0],
259                        &Capture::Ignore(span),
260                        "{capture:?} would overwrite a previous implicit capture group in pattern {i}"
261                    );
262                    capture_types[i][0] = capture;
263                }
264            }
265        }
266
267        Ok(Self {
268            patterns: strings,
269            regex,
270            capturing_regex,
271            captures,
272            capture_types,
273        })
274    }
275
276    pub fn patterns_len(&self) -> usize {
277        self.patterns.len()
278    }
279
280    pub fn first_pattern(&self) -> Span<usize> {
281        self.patterns
282            .iter()
283            .enumerate()
284            .map(|(i, p)| Span::new(p.span(), i))
285            .min_by_key(|span| span.start())
286            .unwrap()
287    }
288
289    pub fn first_pattern_span(&self) -> SourceSpan {
290        self.first_pattern().span()
291    }
292}
293impl<'a, A: Automaton + Clone> RegexSetMatcher<'a, A> {
294    pub fn search<'input>(&self, input: Input<'input>) -> RegexSetSearcher<'a, 'input, A> {
295        RegexSetSearcher::from_matcher_ref(self, input)
296    }
297}
298impl<'a, A: Automaton> RegexSetMatcher<'a, A> {
299    #[inline]
300    pub fn into_searcher<'input>(self, input: Input<'input>) -> RegexSetSearcher<'a, 'input, A> {
301        RegexSetSearcher::from_matcher(self, input)
302    }
303}
304impl<'a, A> fmt::Debug for RegexSetMatcher<'a, A> {
305    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
306        f.debug_struct("RegexSetMatcher")
307            .field("patterns", &self.patterns)
308            .field("captures", &self.captures)
309            .field("capture_types", &self.capture_types)
310            .finish()
311    }
312}
313impl<'a, A> Spanned for RegexSetMatcher<'a, A> {
314    fn span(&self) -> SourceSpan {
315        let start = self.patterns.iter().map(|p| p.start()).min().unwrap();
316        let end = self.patterns.iter().map(|p| p.end()).max().unwrap();
317        SourceSpan::from(start..end)
318    }
319}
320impl<'a, A: Automaton + Clone> MatcherMut for RegexSetMatcher<'a, A> {
321    fn try_match_mut<'input, 'context, C>(
322        &self,
323        input: Input<'input>,
324        context: &mut C,
325    ) -> DiagResult<MatchResult<'input>>
326    where
327        C: Context<'input, 'context> + ?Sized,
328    {
329        let mut searcher = self.search(input);
330        searcher.try_match_next(context)
331    }
332}
333impl<'a, 'input, A> PatternSearcher<'input> for RegexSetSearcher<'a, 'input, A>
334where
335    A: Automaton,
336{
337    type Input = regex_automata::Input<'input>;
338    type PatternID = PatternID;
339
340    fn input(&self) -> &Self::Input {
341        &self.search.input
342    }
343    fn last_match_end(&self) -> Option<usize> {
344        self.search.last_match_end
345    }
346    fn set_last_match_end(&mut self, end: usize) {
347        self.search.last_match_end = Some(end);
348        self.search.input.set_start(end);
349    }
350    fn patterns_len(&self) -> usize {
351        self.patterns.len()
352    }
353    fn pattern_span(&self, id: Self::PatternID) -> SourceSpan {
354        self.patterns[id.as_usize()].span()
355    }
356    fn try_match_next<'context, C>(&mut self, context: &mut C) -> DiagResult<MatchResult<'input>>
357    where
358        C: Context<'input, 'context> + ?Sized,
359    {
360        let (fwd_dfa, rev_dfa) = (self.regex.forward(), self.regex.reverse());
361        let matched;
362        let mut last_end = self
363            .search
364            .last_match_end
365            .unwrap_or(self.search.input.start());
366        loop {
367            if self.search.forward {
368                if let Some((pattern_id, end)) = {
369                    fwd_dfa
370                        .try_search_overlapping_fwd(
371                            &self.search.input,
372                            &mut self.search.overlapping,
373                        )
374                        .expect("match error");
375                    self.search
376                        .overlapping
377                        .get_match()
378                        .map(|hm| (hm.pattern(), hm.offset()))
379                } {
380                    last_end = end;
381                    self.search
382                        .reverse_input
383                        .set_anchored(Anchored::Pattern(pattern_id));
384                    self.search
385                        .reverse_input
386                        .set_range(self.search.input.start()..end);
387                    self.search.forward = false;
388                    self.search.overlapping_reverse = OverlappingState::start();
389                    continue;
390                } else {
391                    matched = None;
392                    break;
393                }
394            } else if let Some((pattern_id, start)) = {
395                rev_dfa
396                    .try_search_overlapping_rev(
397                        &self.search.reverse_input,
398                        &mut self.search.overlapping_reverse,
399                    )
400                    .expect("match error");
401                self.search
402                    .overlapping_reverse
403                    .get_match()
404                    .map(|hm| (hm.pattern(), hm.offset()))
405            } {
406                if start == last_end && !self.search.input.is_char_boundary(last_end) {
407                    continue;
408                }
409                self.search.last_match_end = Some(last_end);
410                matched = Some(regex_automata::Match::new(pattern_id, start..last_end));
411                break;
412            } else {
413                self.search.forward = true;
414            }
415        }
416
417        if let Some(matched) = matched {
418            self.search.captures.clear();
419            let pattern_id = matched.pattern();
420            match self.capturing_regex {
421                CapturingRegex::None => {
422                    let overall_span = SourceSpan::from(matched.range());
423                    let pattern_index = pattern_id.as_usize();
424                    let pattern_span = self.patterns[pattern_index].span();
425                    Ok(MatchResult::ok(MatchInfo {
426                        span: overall_span,
427                        pattern_span,
428                        pattern_id: pattern_index,
429                        captures: vec![],
430                    }))
431                }
432                CapturingRegex::Default {
433                    ref re,
434                    ref mut cache,
435                } => {
436                    let input = self
437                        .search
438                        .input
439                        .clone()
440                        .anchored(Anchored::Pattern(pattern_id))
441                        .range(matched.range());
442                    re.search_captures_with(cache, &input, &mut self.search.captures);
443                    if let Some(matched) = self.search.captures.get_match() {
444                        extract_captures_from_match(
445                            matched,
446                            &self.search,
447                            &self.patterns,
448                            &self.capture_types,
449                            context,
450                        )
451                    } else {
452                        let error = CheckFailedError::MatchError {
453                            span: SourceSpan::from(matched.range()),
454                            input_file: context.input_file(),
455                            labels: vec![RelatedLabel::note(Label::at(matched.range()), context.match_file())],
456                            help: Some("meta regex searcher failed to match the input even though an initial DFA pass found a match".to_string()),
457                        };
458                        Err(Report::new(error))
459                    }
460                }
461                CapturingRegex::Onepass {
462                    ref re,
463                    ref mut cache,
464                } => {
465                    let input = self
466                        .search
467                        .input
468                        .clone()
469                        .anchored(Anchored::Pattern(pattern_id))
470                        .range(matched.range());
471                    re.captures(cache, input, &mut self.search.captures);
472                    if let Some(matched) = self.search.captures.get_match() {
473                        extract_captures_from_match(
474                            matched,
475                            &self.search,
476                            &self.patterns,
477                            &self.capture_types,
478                            context,
479                        )
480                    } else {
481                        let error = CheckFailedError::MatchError {
482                            span: SourceSpan::from(matched.range()),
483                            input_file: context.input_file(),
484                            labels: vec![RelatedLabel::note(Label::at(matched.range()), context.match_file())],
485                            help: Some("onepass regex searcher failed to match the input even though an initial DFA pass found a match".to_string()),
486                        };
487                        Err(Report::new(error))
488                    }
489                }
490            }
491        } else {
492            Ok(MatchResult::failed(
493                CheckFailedError::MatchNoneButExpected {
494                    span: self.span(),
495                    match_file: context.match_file(),
496                    note: None,
497                },
498            ))
499        }
500    }
501}
502
503fn extract_captures_from_match<'a, 'input, 'context, C>(
504    matched: regex_automata::Match,
505    search: &Search<'_>,
506    patterns: &[Span<Cow<'a, str>>],
507    capture_types: &[Vec<Capture>],
508    context: &C,
509) -> DiagResult<MatchResult<'input>>
510where
511    C: Context<'input, 'context> + ?Sized,
512{
513    let pattern_id = matched.pattern();
514    let pattern_index = pattern_id.as_usize();
515    let pattern_span = patterns[pattern_index].span();
516    let overall_span = SourceSpan::from(matched.range());
517    let mut capture_infos = Vec::with_capacity(search.captures.group_len());
518    for (index, (maybe_capture_span, capture)) in search
519        .captures
520        .iter()
521        .zip(capture_types[pattern_id].iter().copied())
522        .enumerate()
523    {
524        if let Some(capture_span) = maybe_capture_span {
525            let input = context.search();
526            let captured = input.as_str(capture_span.range());
527            let capture_span = SourceSpan::from(capture_span.range());
528            let result = regex::try_convert_capture_to_type(
529                pattern_id,
530                index,
531                pattern_span,
532                overall_span,
533                Span::new(capture_span, captured),
534                capture,
535                &search.captures,
536                context,
537            );
538            match result {
539                Ok(capture_info) => {
540                    capture_infos.push(capture_info);
541                }
542                Err(error) => {
543                    return Ok(MatchResult::failed(error));
544                }
545            }
546        }
547    }
548    Ok(MatchResult::ok(MatchInfo {
549        span: overall_span,
550        pattern_span,
551        pattern_id: pattern_index,
552        captures: capture_infos,
553    }))
554}