provenant/license_detection/
automaton.rs1use daachorse::DoubleArrayAhoCorasick;
9
/// One pattern occurrence reported by [`Automaton::find_overlapping_iter`].
///
/// `start..end` is a half-open byte range into the searched haystack.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Match {
    /// Identifier of the matched pattern (the value stored for it in the automaton).
    pub pattern: usize,
    /// Byte offset of the first matched byte (inclusive).
    pub start: usize,
    /// Byte offset one past the last matched byte (exclusive).
    pub end: usize,
}
20
21pub struct Automaton {
26 inner: DoubleArrayAhoCorasick<u32>,
27}
28
29impl std::fmt::Debug for Automaton {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 f.debug_struct("Automaton")
32 .field("num_states", &self.inner.num_states())
33 .field("heap_bytes", &self.inner.heap_bytes())
34 .finish()
35 }
36}
37
38impl Clone for Automaton {
39 fn clone(&self) -> Self {
40 let bytes = self.inner.serialize();
41 Self::deserialize_unchecked(&bytes)
42 }
43}
44
45impl Automaton {
46 pub fn empty() -> Self {
51 let dummy_pattern: &[u8] = &[0xFF, 0xFE, 0xFD, 0xFC, 0xFB, 0xFA, 0xF9, 0xF8];
54 match DoubleArrayAhoCorasick::new([dummy_pattern]) {
55 Ok(ac) => Self { inner: ac },
56 Err(_) => panic!("Failed to create empty automaton"),
57 }
58 }
59
60 #[allow(dead_code)]
64 pub fn build(patterns: &[&[u8]]) -> Self {
65 if patterns.is_empty() {
66 return Self::empty();
67 }
68 let non_empty: Vec<&[u8]> = patterns.iter().copied().filter(|p| !p.is_empty()).collect();
70 if non_empty.is_empty() {
71 return Self::empty();
72 }
73 match DoubleArrayAhoCorasick::new(non_empty) {
74 Ok(ac) => Self { inner: ac },
75 Err(_) => Self::empty(),
76 }
77 }
78
79 pub fn find_overlapping_iter(&self, haystack: &[u8]) -> FindOverlappingIter {
89 FindOverlappingIter::new(&self.inner, haystack)
90 }
91
92 pub fn deserialize_unchecked(bytes: &[u8]) -> Self {
97 let (ac, _) = unsafe { DoubleArrayAhoCorasick::deserialize_unchecked(bytes) };
98 Self { inner: ac }
99 }
100
101 #[allow(dead_code)]
103 pub fn num_states(&self) -> usize {
104 self.inner.num_states()
105 }
106
107 #[allow(dead_code)]
109 pub fn heap_bytes(&self) -> usize {
110 self.inner.heap_bytes()
111 }
112}
113
114impl Default for Automaton {
115 fn default() -> Self {
116 Self::empty()
117 }
118}
119
120pub struct FindOverlappingIter {
129 inner: std::vec::IntoIter<daachorse::Match<u32>>,
130}
131
132impl FindOverlappingIter {
133 fn new(automaton: &DoubleArrayAhoCorasick<u32>, haystack: &[u8]) -> Self {
134 let matches: Vec<_> = automaton.find_overlapping_iter(haystack).collect();
135 Self {
136 inner: matches.into_iter(),
137 }
138 }
139}
140
141impl Iterator for FindOverlappingIter {
142 type Item = Match;
143
144 fn next(&mut self) -> Option<Self::Item> {
145 loop {
146 let m = self.inner.next()?;
147 if m.start() % 2 == 0 {
150 return Some(Match {
151 pattern: m.value() as usize,
152 start: m.start(),
153 end: m.end(),
154 });
155 }
156 }
158 }
159}
160
/// Incrementally collects byte patterns and compiles them into an [`Automaton`].
pub struct AutomatonBuilder {
    /// Patterns in insertion order, as accepted by `add_pattern`.
    /// Duplicates may be present until `build`.
    patterns: Vec<Vec<u8>>,
}
167
168impl AutomatonBuilder {
169 pub fn new() -> Self {
171 Self {
172 patterns: Vec::new(),
173 }
174 }
175
176 pub fn add_pattern(&mut self, pattern: &[u8]) {
180 if !pattern.is_empty() {
181 self.patterns.push(pattern.to_vec());
182 }
183 }
184
185 pub fn build(self) -> Automaton {
190 use std::collections::HashSet;
191
192 if self.patterns.is_empty() {
193 return Automaton::empty();
194 }
195
196 let mut seen: HashSet<Vec<u8>> = HashSet::new();
198 let mut unique_patterns: Vec<&[u8]> = Vec::new();
199 for pattern in &self.patterns {
200 if seen.insert(pattern.clone()) {
201 unique_patterns.push(pattern.as_slice());
202 }
203 }
204
205 if unique_patterns.is_empty() {
206 return Automaton::empty();
207 }
208
209 match DoubleArrayAhoCorasick::new(unique_patterns) {
210 Ok(ac) => Automaton { inner: ac },
211 Err(_) => Automaton::empty(),
212 }
213 }
214}
215
216impl Default for AutomatonBuilder {
217 fn default() -> Self {
218 Self::new()
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects `(pattern, start, end)` triples for every match. They are sorted so that
    /// assertions do not depend on the order in which overlapping matches are reported.
    fn triples(ac: &Automaton, haystack: &[u8]) -> Vec<(usize, usize, usize)> {
        let mut found: Vec<_> = ac
            .find_overlapping_iter(haystack)
            .map(|m| (m.pattern, m.start, m.end))
            .collect();
        found.sort_unstable();
        found
    }

    #[test]
    fn test_empty_automaton() {
        let ac = Automaton::empty();
        let matches: Vec<_> = ac.find_overlapping_iter(b"hello").collect();
        assert!(matches.is_empty());
    }

    #[test]
    fn test_default_is_empty() {
        let ac = Automaton::default();
        assert!(triples(&ac, b"hello").is_empty());
    }

    #[test]
    fn test_build_with_patterns() {
        let patterns: Vec<&[u8]> = vec![b"hello", b"world"];
        let ac = Automaton::build(&patterns);
        assert_eq!(triples(&ac, b"hello world"), vec![(0, 0, 5), (1, 6, 11)]);
    }

    #[test]
    fn test_build_with_no_usable_patterns() {
        let ac = Automaton::build(&[]);
        assert!(triples(&ac, b"hello").is_empty());

        let only_empty: Vec<&[u8]> = vec![b""];
        let ac = Automaton::build(&only_empty);
        assert!(triples(&ac, b"hello").is_empty());
    }

    #[test]
    fn test_token_boundary_filtering() {
        let pattern: &[u8] = &[31, 49];
        let ac = Automaton::build(&[pattern]);

        // The pattern occurs at offset 1, which is not a token boundary.
        let haystack: &[u8] = &[109, 31, 49, 74];
        let matches: Vec<_> = ac.find_overlapping_iter(haystack).collect();
        assert!(
            matches.is_empty(),
            "Should not match across token boundaries"
        );
    }

    #[test]
    fn test_valid_token_match() {
        let pattern: &[u8] = &[31, 49];
        let ac = Automaton::build(&[pattern]);

        let haystack: &[u8] = &[0, 0, 31, 49, 0, 0];
        let matches: Vec<_> = ac.find_overlapping_iter(haystack).collect();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].start, 2);
        assert_eq!(matches[0].end, 4);
    }

    #[test]
    fn test_builder() {
        let mut builder = AutomatonBuilder::new();
        builder.add_pattern(b"hello");
        builder.add_pattern(b"world");
        let ac = builder.build();

        assert_eq!(triples(&ac, b"hello world"), vec![(0, 0, 5), (1, 6, 11)]);
    }

    #[test]
    fn test_builder_empty_patterns() {
        let builder = AutomatonBuilder::new();
        let ac = builder.build();
        let matches: Vec<_> = ac.find_overlapping_iter(b"hello").collect();
        assert!(matches.is_empty());
    }

    #[test]
    fn test_builder_skips_empty_patterns() {
        let mut builder = AutomatonBuilder::new();
        builder.add_pattern(b"");
        builder.add_pattern(b"hello");
        builder.add_pattern(b"");
        let ac = builder.build();

        assert_eq!(triples(&ac, b"hello"), vec![(0, 0, 5)]);
    }

    #[test]
    fn test_builder_deduplicates_patterns() {
        let mut builder = AutomatonBuilder::new();
        builder.add_pattern(b"ab");
        builder.add_pattern(b"ab");
        builder.add_pattern(b"cd");
        let ac = builder.build();

        // The duplicate collapses, so "cd" receives id 1.
        assert_eq!(triples(&ac, b"abcd"), vec![(0, 0, 2), (1, 2, 4)]);
    }

    #[test]
    fn test_clone_matches_original() {
        let patterns: Vec<&[u8]> = vec![b"hello", b"world"];
        let ac = Automaton::build(&patterns);
        let copy = ac.clone();

        assert_eq!(triples(&ac, b"hello world"), triples(&copy, b"hello world"));
    }

    #[test]
    fn test_serialize_deserialize() {
        let patterns: Vec<&[u8]> = vec![b"hello", b"world", b"test"];
        let ac1 = Automaton::build(&patterns);

        let serialized = ac1.inner.serialize();
        let ac2 = Automaton::deserialize_unchecked(&serialized);

        let haystack = b"hello world test";
        let matches1: Vec<_> = ac1.find_overlapping_iter(haystack).collect();
        let matches2: Vec<_> = ac2.find_overlapping_iter(haystack).collect();

        assert_eq!(matches1, matches2);
        assert_eq!(matches1.len(), 3);
    }

    #[test]
    fn test_overlapping_matches() {
        let patterns: Vec<&[u8]> = vec![b"ab", b"bc", b"abc"];
        let ac = Automaton::build(&patterns);

        // "ab" and "abc" both start at offset 0. "bc" starts at odd offset 1 and is
        // dropped by the token-boundary filter.
        assert_eq!(triples(&ac, b"abc"), vec![(0, 0, 2), (2, 0, 3)]);
    }
}