provenant/license_detection/
automaton.rs1use daachorse::DoubleArrayAhoCorasick;
9
/// One pattern occurrence produced by the automaton's overlapping search.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Match {
    /// Value the automaton associates with the matched pattern, widened to `usize`.
    pub pattern: usize,
    /// Byte offset in the haystack at which the occurrence begins.
    pub start: usize,
    /// Byte offset in the haystack just past the last matched byte.
    pub end: usize,
}
20
21pub struct Automaton {
26 inner: DoubleArrayAhoCorasick<u32>,
27}
28
29impl std::fmt::Debug for Automaton {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 f.debug_struct("Automaton")
32 .field("num_states", &self.inner.num_states())
33 .field("heap_bytes", &self.inner.heap_bytes())
34 .finish()
35 }
36}
37
38impl Clone for Automaton {
39 fn clone(&self) -> Self {
40 let bytes = self.inner.serialize();
41 Self::deserialize_unchecked(&bytes)
42 }
43}
44
45impl Automaton {
46 pub fn empty() -> Self {
51 let dummy_pattern: &[u8] = &[0xFF, 0xFE, 0xFD, 0xFC, 0xFB, 0xFA, 0xF9, 0xF8];
54 match DoubleArrayAhoCorasick::new([dummy_pattern]) {
55 Ok(ac) => Self { inner: ac },
56 Err(_) => panic!("Failed to create empty automaton"),
57 }
58 }
59
60 #[allow(dead_code)]
64 pub fn build(patterns: &[&[u8]]) -> Self {
65 if patterns.is_empty() {
66 return Self::empty();
67 }
68 let non_empty: Vec<&[u8]> = patterns.iter().copied().filter(|p| !p.is_empty()).collect();
70 if non_empty.is_empty() {
71 return Self::empty();
72 }
73 match DoubleArrayAhoCorasick::new(non_empty) {
74 Ok(ac) => Self { inner: ac },
75 Err(_) => Self::empty(),
76 }
77 }
78
79 pub fn find_overlapping_iter(&self, haystack: &[u8]) -> FindOverlappingIter {
89 FindOverlappingIter::new(&self.inner, haystack)
90 }
91
92 pub fn deserialize_unchecked(bytes: &[u8]) -> Self {
97 let (ac, _) = unsafe { DoubleArrayAhoCorasick::deserialize_unchecked(bytes) };
98 Self { inner: ac }
99 }
100
101 #[allow(dead_code)]
103 pub fn num_states(&self) -> usize {
104 self.inner.num_states()
105 }
106
107 #[allow(dead_code)]
109 pub fn heap_bytes(&self) -> usize {
110 self.inner.heap_bytes()
111 }
112
113 pub fn serialize_bytes(&self) -> Vec<u8> {
115 self.inner.serialize()
116 }
117}
118
119impl Default for Automaton {
120 fn default() -> Self {
121 Self::empty()
122 }
123}
124
125pub struct FindOverlappingIter {
134 inner: std::vec::IntoIter<daachorse::Match<u32>>,
135}
136
137impl FindOverlappingIter {
138 fn new(automaton: &DoubleArrayAhoCorasick<u32>, haystack: &[u8]) -> Self {
139 let matches: Vec<_> = automaton.find_overlapping_iter(haystack).collect();
140 Self {
141 inner: matches.into_iter(),
142 }
143 }
144}
145
146impl Iterator for FindOverlappingIter {
147 type Item = Match;
148
149 fn next(&mut self) -> Option<Self::Item> {
150 loop {
151 let m = self.inner.next()?;
152 if m.start() % 2 == 0 {
155 return Some(Match {
156 pattern: m.value() as usize,
157 start: m.start(),
158 end: m.end(),
159 });
160 }
161 }
163 }
164}
165
/// Incrementally collects byte patterns and turns them into an `Automaton`.
pub struct AutomatonBuilder {
    /// Owned copies of the patterns added so far, in insertion order.
    patterns: Vec<Vec<u8>>,
}
172
173impl AutomatonBuilder {
174 pub fn new() -> Self {
176 Self {
177 patterns: Vec::new(),
178 }
179 }
180
181 pub fn add_pattern(&mut self, pattern: &[u8]) {
185 if !pattern.is_empty() {
186 self.patterns.push(pattern.to_vec());
187 }
188 }
189
190 pub fn build(self) -> Automaton {
195 use std::collections::HashSet;
196
197 if self.patterns.is_empty() {
198 return Automaton::empty();
199 }
200
201 let mut seen: HashSet<Vec<u8>> = HashSet::new();
203 let mut unique_patterns: Vec<&[u8]> = Vec::new();
204 for pattern in &self.patterns {
205 if seen.insert(pattern.clone()) {
206 unique_patterns.push(pattern.as_slice());
207 }
208 }
209
210 if unique_patterns.is_empty() {
211 return Automaton::empty();
212 }
213
214 match DoubleArrayAhoCorasick::new(unique_patterns) {
215 Ok(ac) => Automaton { inner: ac },
216 Err(_) => Automaton::empty(),
217 }
218 }
219}
220
221impl Default for AutomatonBuilder {
222 fn default() -> Self {
223 Self::new()
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects every match the automaton reports for `haystack`.
    fn scan(ac: &Automaton, haystack: &[u8]) -> Vec<Match> {
        ac.find_overlapping_iter(haystack).collect()
    }

    /// The placeholder automaton reports nothing on ordinary input.
    #[test]
    fn test_empty_automaton() {
        let found = scan(&Automaton::empty(), b"hello");
        assert!(found.is_empty());
    }

    /// Two patterns, each occurring once at an even offset.
    #[test]
    fn test_build_with_patterns() {
        let ac = Automaton::build(&[&b"hello"[..], &b"world"[..]]);
        assert_eq!(scan(&ac, b"hello world").len(), 2);
    }

    /// An occurrence starting at an odd byte offset must be filtered out.
    #[test]
    fn test_token_boundary_filtering() {
        let needle: &[u8] = &[31, 49];
        let ac = Automaton::build(&[needle]);

        // The needle sits at byte offset 1.
        let found = scan(&ac, &[109, 31, 49, 74]);
        assert!(
            found.is_empty(),
            "Should not match across token boundaries"
        );
    }

    /// An occurrence starting at an even byte offset is reported with its span.
    #[test]
    fn test_valid_token_match() {
        let needle: &[u8] = &[31, 49];
        let ac = Automaton::build(&[needle]);

        // The needle sits at byte offset 2.
        let found = scan(&ac, &[0, 0, 31, 49, 0, 0]);
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].start, 2);
        assert_eq!(found[0].end, 4);
    }

    /// The builder produces a working automaton from added patterns.
    #[test]
    fn test_builder() {
        let mut builder = AutomatonBuilder::new();
        for word in [&b"hello"[..], &b"world"[..]] {
            builder.add_pattern(word);
        }
        let ac = builder.build();

        assert_eq!(scan(&ac, b"hello world").len(), 2);
    }

    /// A builder with no patterns yields an automaton that finds nothing.
    #[test]
    fn test_builder_empty_patterns() {
        let ac = AutomatonBuilder::new().build();
        assert!(scan(&ac, b"hello").is_empty());
    }

    /// Empty patterns passed to the builder are ignored.
    #[test]
    fn test_builder_skips_empty_patterns() {
        let mut builder = AutomatonBuilder::new();
        builder.add_pattern(b"");
        builder.add_pattern(b"hello");
        builder.add_pattern(b"");
        let ac = builder.build();

        assert_eq!(scan(&ac, b"hello").len(), 1);
    }

    /// A serialize/deserialize round trip preserves search results.
    #[test]
    fn test_serialize_deserialize() {
        let original = Automaton::build(&[&b"hello"[..], &b"world"[..], &b"test"[..]]);

        let bytes = original.inner.serialize();
        let restored = Automaton::deserialize_unchecked(&bytes);

        let haystack = b"hello world test";
        let before = scan(&original, haystack);
        let after = scan(&restored, haystack);

        assert_eq!(before.len(), after.len());
        for (lhs, rhs) in before.iter().zip(after.iter()) {
            assert_eq!(lhs.pattern, rhs.pattern);
            assert_eq!(lhs.start, rhs.start);
            assert_eq!(lhs.end, rhs.end);
        }
    }

    /// Patterns that overlap in the haystack are all candidates for reporting.
    #[test]
    fn test_overlapping_matches() {
        let ac = Automaton::build(&[&b"ab"[..], &b"bc"[..], &b"abc"[..]]);

        let found = scan(&ac, b"abc");
        assert!(found.len() >= 2);
    }
}