1use std::{
2 borrow::Cow,
3 collections::HashMap,
4 path::Path,
5 sync::{Arc, RwLock},
6};
7
8use super::{nfa::Nfa, nfa::StateId};
9
10#[derive(Clone)]
13pub struct Matcher {
14 nfa: Nfa,
15 transition_cache: Arc<RwLock<HashMap<String, Vec<StateId>>>>,
16}
17
18impl Matcher {
19 pub(crate) fn new(nfa: Nfa) -> Matcher {
20 Self {
21 nfa,
22 transition_cache: Arc::new(RwLock::new(HashMap::new())),
23 }
24 }
25
26 pub fn matching_patterns(&self, path: impl AsRef<Path>) -> Vec<usize> {
30 let components = path
31 .as_ref()
32 .iter()
33 .map(|c| c.to_string_lossy())
34 .collect::<Vec<_>>();
35 let initial_states = self.nfa.initial_states();
36 let final_states = self.next_states(&components, initial_states);
37
38 let mut matches = Vec::new();
39 for state_id in final_states {
40 if let Some(pattern_ids) = &self.nfa.state(state_id).terminal_for_patterns {
43 matches.extend(pattern_ids.iter().copied());
44 }
45 }
46 matches
47 }
48
49 fn next_states(&self, path_segments: &[Cow<str>], start_states: Vec<StateId>) -> Vec<StateId> {
54 if path_segments.is_empty() {
56 return start_states;
57 }
58
59 let subpath_segments = &path_segments[..path_segments.len() - 1];
61 let subpath = subpath_segments.join("/");
62
63 let cached_states = self.get_cached_states_for(&subpath);
65 let states = if let Some(states) = cached_states {
66 states
67 } else {
68 let states = self.next_states(subpath_segments, start_states);
70 self.set_cached_states_for(subpath, states.clone());
71 states
72 };
73
74 let segment = path_segments.last().unwrap();
79 let mut next_states = Vec::new();
80 for state_id in states {
81 self.nfa
82 .transitions_from(state_id)
83 .filter(|transition| transition.is_match(segment))
84 .for_each(|transition| next_states.push(transition.target));
85 }
86
87 let epsilon_nodes = next_states
89 .iter()
90 .flat_map(|&state_id| self.nfa.epsilon_transitions_from(state_id))
91 .collect::<Vec<_>>();
92 next_states.extend(epsilon_nodes);
93 next_states
94 }
95
96 fn get_cached_states_for(&self, path: &str) -> Option<Vec<StateId>> {
97 self.transition_cache
98 .read()
99 .expect("valid lock")
100 .get(path)
101 .cloned()
102 }
103
104 fn set_cached_states_for(&self, path: String, states: Vec<StateId>) {
105 self.transition_cache
106 .write()
107 .expect("valid lock")
108 .insert(path, states);
109 }
110}
111
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use crate::patternset::Builder;

    use super::*;

    #[test]
    fn test_literals() {
        check_all(
            &[
                "/src/parser/mod.rs",
                "/lib/parser/parse.rs",
                "/bin/parser/mod.rs",
                "mod.rs",
            ],
            &[
                ("src/parser/mod.rs", &[0, 3]),
                ("lib/parser/parse.rs", &[1]),
                ("lib/parser/mod.rs", &[3]),
                ("lib/parser/util.rs", &[]),
                ("src/lexer/mod.rs", &[3]),
                ("src/parser/mod.go", &[]),
            ],
        );
    }

    #[test]
    fn test_prefixes() {
        check_all(
            &["src", "src/parser", "src/parser/"],
            &[
                ("src/parser/mod.rs", &[0, 1, 2]),
                ("src/parser", &[0, 1]),
                ("foo/src/parser/mod.rs", &[0]),
            ],
        );
    }

    #[test]
    fn test_anchoring() {
        check_all(
            &["/script/foo", "script/foo", "/foo", "foo"],
            &[
                ("script/foo", &[0, 1, 3]),
                ("foo", &[2, 3]),
                ("bar/script/foo", &[3]),
            ],
        );
    }

    #[test]
    fn test_wildcards() {
        check_all(
            &[
                "src/*/mod.rs",
                "src/parser/*",
                "*/*/mod.rs",
                "src/parser/*/",
            ],
            &[
                ("src/parser/mod.rs", &[0, 1, 2]),
                ("src/lexer/mod.rs", &[0, 2]),
                ("src/parser/parser.rs", &[1]),
                ("test/lexer/mod.rs", &[2]),
                ("parser/mod.rs", &[]),
                ("src/parser/subdir/thing.rs", &[3]),
            ],
        );
    }

    #[test]
    fn test_trailing_wildcards() {
        check_all(
            &["/mammals/*", "/fish/*/"],
            &[
                ("mammals", &[]),
                ("mammals/equus", &[0]),
                ("mammals/equus/zebra", &[]),
                ("fish", &[]),
                ("fish/gaddus", &[]),
                ("fish/gaddus/cod", &[1]),
            ],
        );
    }

    #[test]
    fn test_complex_patterns() {
        check_all(
            &["/src/parser/*.rs", "/src/p*/*.*"],
            &[
                ("src/parser/mod.rs", &[0, 1]),
                ("src/p/lib.go", &[1]),
                ("src/parser/README", &[]),
            ],
        );
    }

    #[test]
    fn test_leading_double_stars() {
        check_all(
            &["/**/baz", "/**/bar/baz"],
            &[
                ("x/y/baz", &[0]),
                ("x/bar/baz", &[0, 1]),
                ("baz", &[0]),
            ],
        );
    }

    #[test]
    fn test_infix_double_stars() {
        check_all(
            &["/foo/**/qux", "/foo/qux"],
            &[
                ("foo/qux", &[0, 1]),
                ("foo/bar/qux", &[0]),
                ("foo/bar/baz/qux", &[0]),
                ("foo/bar", &[]),
                ("bar/qux", &[]),
            ],
        );
    }

    #[test]
    fn test_trailing_double_stars() {
        check_all(
            &["foo/**", "**"],
            &[
                ("foo", &[1]),
                ("bar", &[1]),
                ("foo/bar", &[0, 1]),
                ("x/y/baz", &[1]),
                ("foo/bar/baz", &[0, 1]),
            ],
        );
    }

    #[test]
    fn test_escape_sequences() {
        check_all(
            &["f\\*o", "a*b\\??", "\\*qux", "bar\\*", "\\*"],
            &[
                ("f*o", &[0]),
                ("foo", &[]),
                ("axb?!", &[1]),
                ("axb?", &[]),
                ("axbc!", &[]),
                ("*qux", &[2]),
                ("xqux", &[]),
                ("bar*", &[3]),
                ("bar", &[]),
                ("*", &[4]),
                ("a", &[]),
            ],
        );
    }

    /// Builds one matcher from `patterns` and checks every `(path, expected
    /// pattern indices)` case against it, in order.
    ///
    /// A single matcher is reused for all cases so that later cases also run
    /// against a transition cache warmed by earlier ones.
    fn check_all(patterns: &[&str], cases: &[(&str, &[usize])]) {
        let matcher = matcher_for_patterns(patterns);
        for &(path, expected) in cases {
            assert_matches(&matcher, path, patterns, expected);
        }
    }

    /// Asserts that the set of pattern ids matching `path` equals `expected`,
    /// ignoring order and duplicates.
    fn assert_matches(matcher: &Matcher, path: &str, patterns: &[&str], expected: &[usize]) {
        let actual: HashSet<usize> = matcher.matching_patterns(path).into_iter().collect();
        let wanted: HashSet<usize> = expected.iter().copied().collect();
        assert_eq!(
            actual,
            wanted,
            "expected {:?} to match {:?}",
            path,
            expected.iter().map(|&i| patterns[i]).collect::<Vec<_>>(),
        );
    }

    /// Adds each pattern to a fresh `Builder`, in order, and builds the matcher.
    fn matcher_for_patterns(patterns: &[&str]) -> Matcher {
        let mut builder = Builder::new();
        patterns.iter().for_each(|pattern| {
            builder.add(pattern);
        });
        builder.build()
    }
}