provenant/license_detection/
automaton.rs1use daachorse::DoubleArrayAhoCorasick;
12use rancor::Fallible;
13use rkyv::with::{ArchiveWith, DeserializeWith, SerializeWith};
14use rkyv::{Archive, Deserialize, Place, Serialize};
15
/// Marker type for rkyv's `with` adapters: through the trait impls in this file
/// an [`Automaton`] is archived as the byte vector produced by
/// [`Automaton::serialize_bytes`] and rebuilt from those bytes on deserialization.
pub struct AsBytes;
18
19impl ArchiveWith<Automaton> for AsBytes {
20 type Archived = <Vec<u8> as Archive>::Archived;
21 type Resolver = <Vec<u8> as Archive>::Resolver;
22
23 fn resolve_with(field: &Automaton, resolver: Self::Resolver, out: Place<Self::Archived>) {
24 field.serialize_bytes().resolve(resolver, out);
25 }
26}
27
28impl<S: Fallible + rkyv::ser::Writer + rkyv::ser::Allocator + ?Sized> SerializeWith<Automaton, S>
29 for AsBytes
30{
31 fn serialize_with(field: &Automaton, serializer: &mut S) -> Result<Self::Resolver, S::Error> {
32 field.serialize_bytes().serialize(serializer)
33 }
34}
35
36impl<D: Fallible + ?Sized> DeserializeWith<<Vec<u8> as Archive>::Archived, Automaton, D> for AsBytes
37where
38 <Vec<u8> as Archive>::Archived: Deserialize<Vec<u8>, D>,
39{
40 fn deserialize_with(
41 field: &<Vec<u8> as Archive>::Archived,
42 deserializer: &mut D,
43 ) -> Result<Automaton, D::Error> {
44 let bytes: Vec<u8> = field.deserialize(deserializer)?;
45 Ok(Automaton::deserialize_unchecked(&bytes))
46 }
47}
48
/// One pattern occurrence reported by [`FindOverlappingIter`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Match {
    /// Value associated with the matched pattern, widened from `u32` to `usize`.
    pub pattern: usize,
    /// Byte offset in the haystack at which the occurrence begins.
    pub start: usize,
    /// Byte offset just past the last byte of the occurrence.
    pub end: usize,
}
59
60pub struct Automaton {
65 inner: DoubleArrayAhoCorasick<u32>,
66}
67
68impl std::fmt::Debug for Automaton {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 f.debug_struct("Automaton")
71 .field("num_states", &self.inner.num_states())
72 .field("heap_bytes", &self.inner.heap_bytes())
73 .finish()
74 }
75}
76
77impl Clone for Automaton {
78 fn clone(&self) -> Self {
79 let bytes = self.inner.serialize();
80 Self::deserialize_unchecked(&bytes)
81 }
82}
83
84impl Automaton {
85 pub fn empty() -> Self {
90 let dummy_pattern: &[u8] = &[0xFF, 0xFE, 0xFD, 0xFC, 0xFB, 0xFA, 0xF9, 0xF8];
93 match DoubleArrayAhoCorasick::new([dummy_pattern]) {
94 Ok(ac) => Self { inner: ac },
95 Err(_) => panic!("Failed to create empty automaton"),
96 }
97 }
98
99 #[allow(dead_code)]
103 pub fn build(patterns: &[&[u8]]) -> Self {
104 if patterns.is_empty() {
105 return Self::empty();
106 }
107 let non_empty: Vec<&[u8]> = patterns.iter().copied().filter(|p| !p.is_empty()).collect();
109 if non_empty.is_empty() {
110 return Self::empty();
111 }
112 match DoubleArrayAhoCorasick::new(non_empty) {
113 Ok(ac) => Self { inner: ac },
114 Err(_) => Self::empty(),
115 }
116 }
117
118 pub fn find_overlapping_iter(&self, haystack: &[u8]) -> FindOverlappingIter {
128 FindOverlappingIter::new(&self.inner, haystack)
129 }
130
131 pub fn deserialize_unchecked(bytes: &[u8]) -> Self {
136 let (ac, _) = unsafe { DoubleArrayAhoCorasick::deserialize_unchecked(bytes) };
137 Self { inner: ac }
138 }
139
140 #[allow(dead_code)]
142 pub fn num_states(&self) -> usize {
143 self.inner.num_states()
144 }
145
146 #[allow(dead_code)]
148 pub fn heap_bytes(&self) -> usize {
149 self.inner.heap_bytes()
150 }
151
152 pub fn serialize_bytes(&self) -> Vec<u8> {
154 self.inner.serialize()
155 }
156}
157
158impl Default for Automaton {
159 fn default() -> Self {
160 Self::empty()
161 }
162}
163
164pub struct FindOverlappingIter {
173 inner: std::vec::IntoIter<daachorse::Match<u32>>,
174}
175
176impl FindOverlappingIter {
177 fn new(automaton: &DoubleArrayAhoCorasick<u32>, haystack: &[u8]) -> Self {
178 let matches: Vec<_> = automaton.find_overlapping_iter(haystack).collect();
179 Self {
180 inner: matches.into_iter(),
181 }
182 }
183}
184
185impl Iterator for FindOverlappingIter {
186 type Item = Match;
187
188 fn next(&mut self) -> Option<Self::Item> {
189 loop {
190 let m = self.inner.next()?;
191 if m.start() % 2 == 0 {
194 return Some(Match {
195 pattern: m.value() as usize,
196 start: m.start(),
197 end: m.end(),
198 });
199 }
200 }
202 }
203}
204
/// Accumulates byte patterns one at a time and turns them into an [`Automaton`].
pub struct AutomatonBuilder {
    /// Owned copies of the patterns added so far, in insertion order.
    patterns: Vec<Vec<u8>>,
}
211
212impl AutomatonBuilder {
213 pub fn new() -> Self {
215 Self {
216 patterns: Vec::new(),
217 }
218 }
219
220 pub fn add_pattern(&mut self, pattern: &[u8]) {
224 if !pattern.is_empty() {
225 self.patterns.push(pattern.to_vec());
226 }
227 }
228
229 pub fn build(self) -> Automaton {
234 use std::collections::HashSet;
235
236 if self.patterns.is_empty() {
237 return Automaton::empty();
238 }
239
240 let mut seen: HashSet<Vec<u8>> = HashSet::new();
242 let mut unique_patterns: Vec<&[u8]> = Vec::new();
243 for pattern in &self.patterns {
244 if seen.insert(pattern.clone()) {
245 unique_patterns.push(pattern.as_slice());
246 }
247 }
248
249 if unique_patterns.is_empty() {
250 return Automaton::empty();
251 }
252
253 match DoubleArrayAhoCorasick::new(unique_patterns) {
254 Ok(ac) => Automaton { inner: ac },
255 Err(_) => Automaton::empty(),
256 }
257 }
258}
259
260impl Default for AutomatonBuilder {
261 fn default() -> Self {
262 Self::new()
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `automaton` over `haystack` and gathers every reported match.
    fn matches_in(automaton: &Automaton, haystack: &[u8]) -> Vec<Match> {
        automaton.find_overlapping_iter(haystack).collect()
    }

    #[test]
    fn test_empty_automaton() {
        let automaton = Automaton::empty();
        assert!(matches_in(&automaton, b"hello").is_empty());
    }

    #[test]
    fn test_build_with_patterns() {
        let patterns: Vec<&[u8]> = vec![b"hello", b"world"];
        let automaton = Automaton::build(&patterns);
        assert_eq!(matches_in(&automaton, b"hello world").len(), 2);
    }

    #[test]
    fn test_token_boundary_filtering() {
        let pattern: &[u8] = &[31, 49];
        let automaton = Automaton::build(&[pattern]);

        // The pattern sits at byte offset 1, an odd offset, so it must be dropped.
        let haystack: &[u8] = &[109, 31, 49, 74];
        let found = matches_in(&automaton, haystack);
        assert!(
            found.is_empty(),
            "Should not match across token boundaries"
        );
    }

    #[test]
    fn test_valid_token_match() {
        let pattern: &[u8] = &[31, 49];
        let automaton = Automaton::build(&[pattern]);

        // The pattern sits at byte offset 2, an even offset, so it is reported.
        let haystack: &[u8] = &[0, 0, 31, 49, 0, 0];
        let found = matches_in(&automaton, haystack);
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].start, 2);
        assert_eq!(found[0].end, 4);
    }

    #[test]
    fn test_builder() {
        let mut builder = AutomatonBuilder::new();
        for pattern in [b"hello".as_slice(), b"world".as_slice()] {
            builder.add_pattern(pattern);
        }
        let automaton = builder.build();

        assert_eq!(matches_in(&automaton, b"hello world").len(), 2);
    }

    #[test]
    fn test_builder_empty_patterns() {
        let automaton = AutomatonBuilder::new().build();
        assert!(matches_in(&automaton, b"hello").is_empty());
    }

    #[test]
    fn test_builder_skips_empty_patterns() {
        let mut builder = AutomatonBuilder::new();
        for pattern in [b"".as_slice(), b"hello".as_slice(), b"".as_slice()] {
            builder.add_pattern(pattern);
        }
        let automaton = builder.build();

        assert_eq!(matches_in(&automaton, b"hello").len(), 1);
    }

    #[test]
    fn test_serialize_deserialize() {
        let patterns: Vec<&[u8]> = vec![b"hello", b"world", b"test"];
        let original = Automaton::build(&patterns);

        let serialized = original.inner.serialize();
        let restored = Automaton::deserialize_unchecked(&serialized);

        let haystack = b"hello world test";
        let before = matches_in(&original, haystack);
        let after = matches_in(&restored, haystack);

        assert_eq!(before.len(), after.len());
        for (lhs, rhs) in before.iter().zip(after.iter()) {
            assert_eq!(lhs.pattern, rhs.pattern);
            assert_eq!(lhs.start, rhs.start);
            assert_eq!(lhs.end, rhs.end);
        }
    }

    #[test]
    fn test_overlapping_matches() {
        let patterns: Vec<&[u8]> = vec![b"ab", b"bc", b"abc"];
        let automaton = Automaton::build(&patterns);

        let found = matches_in(&automaton, b"abc");
        assert!(found.len() >= 2);
    }
}