codeowners_rs/patternset/
matcher.rs

1use std::{
2    borrow::Cow,
3    collections::HashMap,
4    path::Path,
5    sync::{Arc, RwLock},
6};
7
8use super::{nfa::Nfa, nfa::StateId};
9
10/// Matches a path against a set of patterns. Includes a thread-safe transition
11/// cache to speed up subsequent lookups. Created using a [`super::Builder`].
12#[derive(Clone)]
13pub struct Matcher {
14    nfa: Nfa,
15    transition_cache: Arc<RwLock<HashMap<String, Vec<StateId>>>>,
16}
17
18impl Matcher {
19    pub(crate) fn new(nfa: Nfa) -> Matcher {
20        Self {
21            nfa,
22            transition_cache: Arc::new(RwLock::new(HashMap::new())),
23        }
24    }
25
26    /// Match a path against the patterns in the set. Returns a list of pattern
27    /// indices that match the path. The pattern indices match the order in which
28    /// the patterns were added to the builder.
29    pub fn matching_patterns(&self, path: impl AsRef<Path>) -> Vec<usize> {
30        let components = path
31            .as_ref()
32            .iter()
33            .map(|c| c.to_string_lossy())
34            .collect::<Vec<_>>();
35        let initial_states = self.nfa.initial_states();
36        let final_states = self.next_states(&components, initial_states);
37
38        let mut matches = Vec::new();
39        for state_id in final_states {
40            // After processing the path, find the states we're in that are
41            // terminal, and return the pattern ids for those states.
42            if let Some(pattern_ids) = &self.nfa.state(state_id).terminal_for_patterns {
43                matches.extend(pattern_ids.iter().copied());
44            }
45        }
46        matches
47    }
48
49    // Given a set of states and a slice of path components, return the set of
50    // states we're in after stepping through the NFA. This is the core of the
51    // matching logic. `next_states` calls itself recursively until the path
52    // segment slice is empty.
53    fn next_states(&self, path_segments: &[Cow<str>], start_states: Vec<StateId>) -> Vec<StateId> {
54        // Base case - no more path segments to match
55        if path_segments.is_empty() {
56            return start_states;
57        }
58
59        // Get the states for the current path's prefix
60        let subpath_segments = &path_segments[..path_segments.len() - 1];
61        let subpath = subpath_segments.join("/");
62
63        // Start by checking the cache
64        let cached_states = self.get_cached_states_for(&subpath);
65        let states = if let Some(states) = cached_states {
66            states
67        } else {
68            // If the cache doesn't have the states, recursively compute them
69            let states = self.next_states(subpath_segments, start_states);
70            self.set_cached_states_for(subpath, states.clone());
71            states
72        };
73
74        // Now that we have the states for the current path's prefix, compute the
75        // next states for the current path by following the matching transitions for
76        // the current set of states we're in. The `unwrap` won't panic because we
77        // checked that the slice isn't empty above.
78        let segment = path_segments.last().unwrap();
79        let mut next_states = Vec::new();
80        for state_id in states {
81            self.nfa
82                .transitions_from(state_id)
83                .filter(|transition| transition.is_match(segment))
84                .for_each(|transition| next_states.push(transition.target));
85        }
86
87        // Automatically traverse epsilon edges
88        let epsilon_nodes = next_states
89            .iter()
90            .flat_map(|&state_id| self.nfa.epsilon_transitions_from(state_id))
91            .collect::<Vec<_>>();
92        next_states.extend(epsilon_nodes);
93        next_states
94    }
95
96    fn get_cached_states_for(&self, path: &str) -> Option<Vec<StateId>> {
97        self.transition_cache
98            .read()
99            .expect("valid lock")
100            .get(path)
101            .cloned()
102    }
103
104    fn set_cached_states_for(&self, path: String, states: Vec<StateId>) {
105        self.transition_cache
106            .write()
107            .expect("valid lock")
108            .insert(path, states);
109    }
110}
111
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use crate::patternset::Builder;

    use super::*;

    #[test]
    fn test_literals() {
        check_cases(
            &[
                "/src/parser/mod.rs",
                "/lib/parser/parse.rs",
                "/bin/parser/mod.rs",
                "mod.rs",
            ],
            &[
                ("src/parser/mod.rs", &[0, 3]),
                ("lib/parser/parse.rs", &[1]),
                ("lib/parser/mod.rs", &[3]),
                ("lib/parser/util.rs", &[]),
                ("src/lexer/mod.rs", &[3]),
                ("src/parser/mod.go", &[]),
            ],
        );
    }

    #[test]
    fn test_prefixes() {
        check_cases(
            &["src", "src/parser", "src/parser/"],
            &[
                ("src/parser/mod.rs", &[0, 1, 2]),
                ("src/parser", &[0, 1]),
                ("foo/src/parser/mod.rs", &[0]),
            ],
        );
    }

    #[test]
    fn test_anchoring() {
        check_cases(
            &["/script/foo", "script/foo", "/foo", "foo"],
            &[
                ("script/foo", &[0, 1, 3]),
                ("foo", &[2, 3]),
                ("bar/script/foo", &[3]),
            ],
        );
    }

    #[test]
    fn test_wildcards() {
        check_cases(
            &[
                "src/*/mod.rs",
                "src/parser/*",
                "*/*/mod.rs",
                "src/parser/*/",
            ],
            &[
                ("src/parser/mod.rs", &[0, 1, 2]),
                ("src/lexer/mod.rs", &[0, 2]),
                ("src/parser/parser.rs", &[1]),
                ("test/lexer/mod.rs", &[2]),
                ("parser/mod.rs", &[]),
                ("src/parser/subdir/thing.rs", &[3]),
            ],
        );
    }

    #[test]
    fn test_trailing_wildcards() {
        check_cases(
            &["/mammals/*", "/fish/*/"],
            &[
                ("mammals", &[]),
                ("mammals/equus", &[0]),
                ("mammals/equus/zebra", &[]),
                ("fish", &[]),
                ("fish/gaddus", &[]),
                ("fish/gaddus/cod", &[1]),
            ],
        );
    }

    #[test]
    fn test_complex_patterns() {
        check_cases(
            &["/src/parser/*.rs", "/src/p*/*.*"],
            &[
                ("src/parser/mod.rs", &[0, 1]),
                ("src/p/lib.go", &[1]),
                ("src/parser/README", &[]),
            ],
        );
    }

    #[test]
    fn test_leading_double_stars() {
        check_cases(
            &["/**/baz", "/**/bar/baz"],
            &[
                ("x/y/baz", &[0]),
                ("x/bar/baz", &[0, 1]),
                ("baz", &[0]),
            ],
        );
    }

    #[test]
    fn test_infix_double_stars() {
        check_cases(
            &["/foo/**/qux", "/foo/qux"],
            &[
                ("foo/qux", &[0, 1]),
                ("foo/bar/qux", &[0]),
                ("foo/bar/baz/qux", &[0]),
                ("foo/bar", &[]),
                ("bar/qux", &[]),
            ],
        );
    }

    #[test]
    fn test_trailing_double_stars() {
        check_cases(
            &["foo/**", "**"],
            &[
                ("foo", &[1]),
                ("bar", &[1]),
                ("foo/bar", &[0, 1]),
                ("x/y/baz", &[1]),
                ("foo/bar/baz", &[0, 1]),
            ],
        );
    }

    #[test]
    fn test_escape_sequences() {
        check_cases(
            &["f\\*o", "a*b\\??", "\\*qux", "bar\\*", "\\*"],
            &[
                ("f*o", &[0]),
                ("foo", &[]),
                ("axb?!", &[1]),
                ("axb?", &[]),
                ("axbc!", &[]),
                ("*qux", &[2]),
                ("xqux", &[]),
                ("bar*", &[3]),
                ("bar", &[]),
                ("*", &[4]),
                ("a", &[]),
            ],
        );
    }

    /// Builds one matcher for `patterns` and runs every `(path, expected
    /// pattern indices)` case against it, in the order given. The same matcher
    /// (and so the same transition cache) is reused across all cases.
    fn check_cases(patterns: &[&str], cases: &[(&str, &[usize])]) {
        let matcher = matcher_for_patterns(patterns);
        for &(path, expected) in cases {
            assert_matches(&matcher, path, patterns, expected);
        }
    }

    /// Asserts that the set of pattern indices matching `path` equals
    /// `expected`. Both sides are compared as sets, so neither order nor
    /// repeated indices affect the outcome.
    fn assert_matches(matcher: &Matcher, path: &str, patterns: &[&str], expected: &[usize]) {
        let actual: HashSet<usize> = matcher.matching_patterns(path).into_iter().collect();
        let wanted: HashSet<usize> = expected.iter().copied().collect();
        assert_eq!(
            actual,
            wanted,
            "expected {:?} to match {:?}",
            path,
            expected.iter().map(|&i| patterns[i]).collect::<Vec<_>>(),
        );
    }

    /// Builds a matcher from `patterns`, adding them to the builder in slice
    /// order so that each pattern's index equals its position in the slice.
    fn matcher_for_patterns(patterns: &[&str]) -> Matcher {
        let mut builder = Builder::new();
        patterns.iter().for_each(|pattern| {
            builder.add(pattern);
        });
        builder.build()
    }
}