litcheck_filecheck/pattern/matcher/matchers/substring_set.rs

1use aho_corasick::{AhoCorasick, AhoCorasickBuilder, AhoCorasickKind, MatchKind, StartKind};
2
3use crate::{common::*, pattern::search::SubstringSetSearcher};
4
5/// This matcher is a variation on [SubstringMatcher] that searches
6/// for a match of any of multiple substrings in the input buffer.
7///
8/// This is much more efficient than performing multiple independent searches
9/// with [SubstringMatcher], so should be used whenever multiple substring
10/// patterns could be matched at the same time.
11pub struct SubstringSetMatcher<'a> {
12    /// The set of patterns to be matched
13    patterns: Vec<Span<Cow<'a, str>>>,
14    /// The automaton that will perform the search
15    searcher: AhoCorasick,
16}
17impl<'a> fmt::Debug for SubstringSetMatcher<'a> {
18    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
19        f.debug_struct("SubstringSetMatcher")
20            .field("patterns", &self.patterns)
21            .field("kind", &self.searcher.kind())
22            .field("start_kind", &self.searcher.start_kind())
23            .field("match_kind", &self.searcher.match_kind())
24            .finish()
25    }
26}
27impl<'a> SubstringSetMatcher<'a> {
28    /// Create a new matcher for the given set of substring patterns
29    ///
30    /// NOTE: This function will panic if the set is empty.
31    pub fn new(patterns: Vec<Span<Cow<'a, str>>>, config: &Config) -> DiagResult<Self> {
32        let patterns = patterns
33            .into_iter()
34            .map(|p| {
35                p.map(|p| text::canonicalize_horizontal_whitespace(p, config.strict_whitespace))
36            })
37            .collect();
38
39        let mut builder = SubstringSetBuilder::new_with_patterns(patterns);
40        builder.case_insensitive(config.ignore_case);
41        builder.build()
42    }
43
44    pub fn search<'input, 'patterns>(
45        &'patterns self,
46        input: Input<'input>,
47    ) -> DiagResult<SubstringSetSearcher<'a, 'patterns, 'input>> {
48        SubstringSetSearcher::new(input, Cow::Borrowed(&self.patterns))
49    }
50
51    pub fn pattern_len(&self) -> usize {
52        self.patterns.len()
53    }
54
55    pub fn first_pattern(&self) -> Span<usize> {
56        self.patterns
57            .iter()
58            .enumerate()
59            .map(|(i, p)| Span::new(p.span(), i))
60            .min_by_key(|span| span.start())
61            .unwrap()
62    }
63
64    pub fn first_pattern_span(&self) -> SourceSpan {
65        self.first_pattern().span()
66    }
67
68    /// Get a builder for configuring and building a new [SubstringSetMatcher]
69    pub fn build() -> SubstringSetBuilder<'a> {
70        SubstringSetBuilder::default()
71    }
72
73    /// Search for all of the non-overlapping matches in the input
74    pub fn try_match_all<'input>(&self, input: Input<'input>) -> Vec<MatchInfo<'input>> {
75        let mut matches = vec![];
76        for matched in self.searcher.find_iter(input) {
77            let pattern_id = matched.pattern().as_usize();
78            let pattern_span = self.patterns[pattern_id].span();
79            let span = SourceSpan::from(matched.range());
80            matches.push(MatchInfo::new_with_pattern(span, pattern_span, pattern_id))
81        }
82        matches
83    }
84
85    /// Search for all of the matches in the input, including overlapping matches
86    pub fn try_match_overlapping<'input>(&self, input: Input<'input>) -> Vec<MatchInfo<'input>> {
87        let mut matches = vec![];
88        for matched in self.searcher.find_overlapping_iter(input) {
89            let pattern_id = matched.pattern().as_usize();
90            let pattern_span = self.patterns[pattern_id].span();
91            let span = SourceSpan::from(matched.range());
92            matches.push(MatchInfo::new_with_pattern(span, pattern_span, pattern_id))
93        }
94        matches
95    }
96}
97impl<'a> Spanned for SubstringSetMatcher<'a> {
98    fn span(&self) -> SourceSpan {
99        let start = self.patterns.iter().map(|p| p.start()).min().unwrap();
100        let end = self.patterns.iter().map(|p| p.end()).max().unwrap();
101        SourceSpan::from(start..end)
102    }
103}
104impl<'a> MatcherMut for SubstringSetMatcher<'a> {
105    fn try_match_mut<'input, 'context, C>(
106        &self,
107        input: Input<'input>,
108        context: &mut C,
109    ) -> DiagResult<MatchResult<'input>>
110    where
111        C: Context<'input, 'context> + ?Sized,
112    {
113        self.try_match(input, context)
114    }
115}
116impl<'a> Matcher for SubstringSetMatcher<'a> {
117    fn try_match<'input, 'context, C>(
118        &self,
119        input: Input<'input>,
120        context: &C,
121    ) -> DiagResult<MatchResult<'input>>
122    where
123        C: Context<'input, 'context> + ?Sized,
124    {
125        if let Some(matched) = self.searcher.find(input) {
126            let pattern_id = matched.pattern().as_usize();
127            let pattern_span = self.patterns[pattern_id].span();
128            Ok(MatchResult::ok(MatchInfo::new_with_pattern(
129                matched.range(),
130                pattern_span,
131                pattern_id,
132            )))
133        } else {
134            Ok(MatchResult::failed(
135                CheckFailedError::MatchNoneButExpected {
136                    span: self.span(),
137                    match_file: context.match_file(),
138                    note: None,
139                },
140            ))
141        }
142    }
143}
144
145pub struct SubstringSetBuilder<'a> {
146    patterns: Vec<Span<Cow<'a, str>>>,
147    start_kind: Option<StartKind>,
148    match_kind: Option<MatchKind>,
149    case_insensitive: bool,
150    support_overlapping_matches: bool,
151}
152impl<'a> Default for SubstringSetBuilder<'a> {
153    #[inline]
154    fn default() -> Self {
155        Self::new_with_patterns(vec![])
156    }
157}
158impl<'a> SubstringSetBuilder<'a> {
159    #[inline]
160    pub fn new() -> Self {
161        Self {
162            patterns: vec![],
163            start_kind: None,
164            match_kind: None,
165            case_insensitive: false,
166            support_overlapping_matches: false,
167        }
168    }
169
170    #[inline]
171    pub fn new_with_patterns(patterns: Vec<Span<Cow<'a, str>>>) -> Self {
172        Self {
173            patterns,
174            start_kind: None,
175            match_kind: None,
176            case_insensitive: false,
177            support_overlapping_matches: false,
178        }
179    }
180
181    /// Add `pattern` to the set of substrings to match
182    pub fn with_pattern(&mut self, pattern: Span<Cow<'a, str>>) -> &mut Self {
183        self.patterns.push(pattern);
184        self
185    }
186
187    /// Add `patterns` to the set of substrings to match
188    pub fn with_patterns<I>(&mut self, patterns: I) -> &mut Self
189    where
190        I: IntoIterator<Item = Span<Cow<'a, str>>>,
191    {
192        self.patterns.extend(patterns);
193        self
194    }
195
196    /// Set whether or not the matcher will support anchored searches.
197    ///
198    /// Since supporting anchored searches can be significantly more expensive,
199    /// you should only do so if you need such searches.
200    pub fn support_anchored_search(&mut self, yes: bool) -> &mut Self {
201        self.start_kind = if yes {
202            Some(StartKind::Both)
203        } else {
204            Some(StartKind::Unanchored)
205        };
206        self
207    }
208
209    /// Set whether or not the matcher will support asking for overlapping matches
210    ///
211    /// NOTE: This will force the [MatchKind] to `MatchKind::Standard`, as other semantics
212    /// are not support when computing overlapping matches.
213    pub fn support_overlapping_matches(&mut self, yes: bool) -> &mut Self {
214        self.support_overlapping_matches = yes;
215        self.match_kind = Some(MatchKind::Standard);
216        self
217    }
218
219    /// Set whether or not the search is case-sensitive.
220    ///
221    /// NOTE: This only applies to ASCII characters, if you require unicode
222    /// case insensitivity, you should use [RegexMatcher] or [RegexSetMatcher]
223    /// instead.
224    pub fn case_insensitive(&mut self, yes: bool) -> &mut Self {
225        self.case_insensitive = yes;
226        self
227    }
228
229    /// Configure the match semantics for this matcher
230    ///
231    /// NOTE: This function will panic if the provided option conflicts with
232    /// the `support_overlapping_matches` setting, as only the default semantics
233    /// are supported when overlapping matches are computed.
234    pub fn match_kind(&mut self, kind: MatchKind) -> &mut Self {
235        assert!(
236            !self.support_overlapping_matches,
237            "cannot support {kind:?} when overlapping matches are enabled"
238        );
239        self.match_kind = Some(kind);
240        self
241    }
242
243    /// Build the [SubstringSetMatcher]
244    ///
245    /// This function will panic if there are no patterns configured, or if
246    /// an incompatible configuration is provided.
247    pub fn build(self) -> DiagResult<SubstringSetMatcher<'a>> {
248        assert!(
249            !self.patterns.is_empty(),
250            "there must be at least one pattern in the set"
251        );
252
253        let kind = if self.support_overlapping_matches {
254            Some(AhoCorasickKind::DFA)
255        } else {
256            None
257        };
258        let match_kind = self.match_kind.unwrap_or(MatchKind::LeftmostLongest);
259        let start_kind = self.start_kind.unwrap_or(StartKind::Unanchored);
260        let mut builder = AhoCorasickBuilder::new();
261        builder
262            .ascii_case_insensitive(self.case_insensitive)
263            .match_kind(match_kind)
264            .start_kind(start_kind)
265            .kind(kind);
266        let searcher = builder
267            .build(self.patterns.iter().map(|p| p.as_bytes()))
268            .map_err(|err| {
269                let labels = self
270                    .patterns
271                    .iter()
272                    .map(|s| Label::new(s.span(), err.to_string()).into());
273                let diag = Diag::new("failed to build multi-substring aho-corasick searcher")
274                    .and_labels(labels)
275                    .with_help(format!(
276                        "search configuration: {kind:?}, {match_kind:?}, {start_kind:?}"
277                    ));
278                Report::from(diag)
279            })?;
280        Ok(SubstringSetMatcher {
281            patterns: self.patterns,
282            searcher,
283        })
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Input shared by every test below.
    ///
    /// It contains two small LLVM IR functions, each holding a single
    /// `tail call` instruction. The `i32` call in `@sub1` comes first.
    const INPUT: &str = "
define void @sub1(i32* %p, i32 %v) {
entry:
        %0 = tail call i32 @llvm.atomic.load.sub.i32.p0i32(i32* %p, i32 %v)
        ret void
}

define void @inc4(i64* %p) {
entry:
        %0 = tail call i64 @llvm.atomic.load.add.i64.p0i64(i64* %p, i64 1)
        ret void
}
";

    #[test]
    fn test_multi_substring_matcher_overlapping() -> DiagResult<()> {
        let mut context = TestContext::new();
        context
            .with_checks("\nCHECK-DAG: tail call i64\nCHECK-DAG: tail call i32\n")
            .with_input(INPUT);

        // The two patterns share the prefix `tail call i`. The `i32` call
        // occurs first in INPUT, so it is reported even though it was listed
        // second.
        let patterns = vec![
            Span::new(12..24, Cow::Borrowed("tail call i64")),
            Span::new(25..41, Cow::Borrowed("tail call i32")),
        ];
        let matcher = SubstringSetMatcher::new(patterns, &context.config)
            .expect("expected pattern to be valid");

        let mctx = context.match_context();
        let input = mctx.search();
        let info = matcher
            .try_match(input, &mctx)?
            .info
            .expect("expected match");
        assert_eq!(info.span.offset(), 58);
        assert_eq!(info.span.len(), 13);
        assert_eq!(input.as_str(info.matched_range()), "tail call i32");
        Ok(())
    }

    #[test]
    fn test_multi_substring_matcher_overlapped() -> DiagResult<()> {
        let mut context = TestContext::new();
        context
            .with_checks("\nCHECK-DAG: tail call i32\nCHECK-DAG: tail call\n")
            .with_input(INPUT);

        // `tail call` is a prefix of `tail call i32`. The longer pattern is
        // the one that must be reported.
        let patterns = vec![
            Span::new(12..24, Cow::Borrowed("tail call i32")),
            Span::new(25..37, Cow::Borrowed("tail call")),
        ];
        let matcher = SubstringSetMatcher::new(patterns, &context.config)
            .expect("expected pattern to be valid");

        let mctx = context.match_context();
        let input = mctx.search();
        let info = matcher
            .try_match(input, &mctx)?
            .info
            .expect("expected match");
        assert_eq!(info.span.offset(), 58);
        assert_eq!(info.span.len(), 13);
        assert_eq!(input.as_str(info.matched_range()), "tail call i32");
        Ok(())
    }

    #[test]
    fn test_multi_substring_matcher_disjoint() -> DiagResult<()> {
        let mut context = TestContext::new();
        context
            .with_checks("\nCHECK-DAG: inc4\nCHECK-DAG: sub1\n")
            .with_input(INPUT);

        // Unrelated patterns: `sub1` appears before `inc4` in INPUT.
        let patterns = vec![
            Span::new(12..17, Cow::Borrowed("inc4")),
            Span::new(19..35, Cow::Borrowed("sub1")),
        ];
        let matcher = SubstringSetMatcher::new(patterns, &context.config)
            .expect("expected pattern to be valid");

        let mctx = context.match_context();
        let input = mctx.search();
        let info = matcher
            .try_match(input, &mctx)?
            .info
            .expect("expected match");
        assert_eq!(info.span.offset(), 14);
        assert_eq!(info.span.len(), 4);
        assert_eq!(input.as_str(info.matched_range()), "sub1");
        Ok(())
    }

    #[test]
    fn test_multi_substring_matcher_anchored() -> DiagResult<()> {
        let mut context = TestContext::new();
        context
            .with_checks("\nCHECK-DAG: @inc4\nCHECK-DAG: @sub1\n")
            .with_input(INPUT);

        let mut builder = SubstringSetMatcher::build();
        builder
            .with_patterns([
                Span::new(0..0, Cow::Borrowed("@inc4")),
                Span::new(0..0, Cow::Borrowed("@sub1")),
            ])
            .support_anchored_search(true);
        let matcher = builder.build().expect("expected pattern to be valid");
        let mctx = context.match_context();

        // Anchored at offset 13, which is exactly where `@sub1` begins.
        let anchored_at_sub1 = mctx.search_range(13..).anchored(true);
        let info = matcher
            .try_match(anchored_at_sub1, &mctx)?
            .info
            .expect("expected match");
        assert_eq!(info.span.offset(), 13);
        assert_eq!(info.span.len(), 5);
        assert_eq!(anchored_at_sub1.as_str(info.matched_range()), "@sub1");

        // Anchored at the very start of INPUT (a newline), so neither pattern can match.
        let anchored_at_start = mctx.search().anchored(true);
        let result = matcher.try_match(anchored_at_start, &mctx)?;
        assert_eq!(result.info, None);
        Ok(())
    }
}
453}