provenant/license_detection/
automaton.rs1use daachorse::DoubleArrayAhoCorasick;
9use rancor::Fallible;
10use rkyv::with::{ArchiveWith, DeserializeWith, SerializeWith};
11use rkyv::{Archive, Deserialize, Place, Serialize};
12
/// Marker type implementing rkyv's `ArchiveWith` / `SerializeWith` /
/// `DeserializeWith` for [`Automaton`], archiving it as the `Vec<u8>` returned by
/// [`Automaton::serialize_bytes`].
pub struct AsBytes;
15
16impl ArchiveWith<Automaton> for AsBytes {
17 type Archived = <Vec<u8> as Archive>::Archived;
18 type Resolver = <Vec<u8> as Archive>::Resolver;
19
20 fn resolve_with(field: &Automaton, resolver: Self::Resolver, out: Place<Self::Archived>) {
21 field.serialize_bytes().resolve(resolver, out);
22 }
23}
24
25impl<S: Fallible + rkyv::ser::Writer + rkyv::ser::Allocator + ?Sized> SerializeWith<Automaton, S>
26 for AsBytes
27{
28 fn serialize_with(field: &Automaton, serializer: &mut S) -> Result<Self::Resolver, S::Error> {
29 field.serialize_bytes().serialize(serializer)
30 }
31}
32
33impl<D: Fallible + ?Sized> DeserializeWith<<Vec<u8> as Archive>::Archived, Automaton, D> for AsBytes
34where
35 <Vec<u8> as Archive>::Archived: Deserialize<Vec<u8>, D>,
36{
37 fn deserialize_with(
38 field: &<Vec<u8> as Archive>::Archived,
39 deserializer: &mut D,
40 ) -> Result<Automaton, D::Error> {
41 let bytes: Vec<u8> = field.deserialize(deserializer)?;
42 Ok(Automaton::deserialize_unchecked(&bytes))
43 }
44}
45
/// A single pattern occurrence reported by [`FindOverlappingIter`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Match {
    /// Value associated with the matched pattern in the automaton.
    pub pattern: usize,
    /// Byte offset in the haystack where the match starts.
    pub start: usize,
    /// Byte offset in the haystack where the match ends.
    pub end: usize,
}
56
/// Wrapper around a daachorse double-array Aho-Corasick automaton whose match
/// values are `u32`.
pub struct Automaton {
    /// The underlying daachorse automaton.
    inner: DoubleArrayAhoCorasick<u32>,
}
64
65impl std::fmt::Debug for Automaton {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 f.debug_struct("Automaton")
68 .field("num_states", &self.inner.num_states())
69 .field("heap_bytes", &self.inner.heap_bytes())
70 .finish()
71 }
72}
73
74impl Clone for Automaton {
75 fn clone(&self) -> Self {
76 let bytes = self.inner.serialize();
77 Self::deserialize_unchecked(&bytes)
78 }
79}
80
81impl Automaton {
82 pub fn empty() -> Self {
87 let dummy_pattern: &[u8] = &[0xFF, 0xFE, 0xFD, 0xFC, 0xFB, 0xFA, 0xF9, 0xF8];
90 match DoubleArrayAhoCorasick::new([dummy_pattern]) {
91 Ok(ac) => Self { inner: ac },
92 Err(_) => panic!("Failed to create empty automaton"),
93 }
94 }
95
96 #[allow(dead_code)]
100 pub fn build(patterns: &[&[u8]]) -> Self {
101 if patterns.is_empty() {
102 return Self::empty();
103 }
104 let non_empty: Vec<&[u8]> = patterns.iter().copied().filter(|p| !p.is_empty()).collect();
106 if non_empty.is_empty() {
107 return Self::empty();
108 }
109 match DoubleArrayAhoCorasick::new(non_empty) {
110 Ok(ac) => Self { inner: ac },
111 Err(_) => Self::empty(),
112 }
113 }
114
115 pub fn find_overlapping_iter(&self, haystack: &[u8]) -> FindOverlappingIter {
125 FindOverlappingIter::new(&self.inner, haystack)
126 }
127
128 pub fn deserialize_unchecked(bytes: &[u8]) -> Self {
133 let (ac, _) = unsafe { DoubleArrayAhoCorasick::deserialize_unchecked(bytes) };
134 Self { inner: ac }
135 }
136
137 #[allow(dead_code)]
139 pub fn num_states(&self) -> usize {
140 self.inner.num_states()
141 }
142
143 #[allow(dead_code)]
145 pub fn heap_bytes(&self) -> usize {
146 self.inner.heap_bytes()
147 }
148
149 pub fn serialize_bytes(&self) -> Vec<u8> {
151 self.inner.serialize()
152 }
153}
154
155impl Default for Automaton {
156 fn default() -> Self {
157 Self::empty()
158 }
159}
160
/// Iterator over the matches found in a haystack.
///
/// It owns its matches (a `Vec` iterator), so it borrows neither the automaton
/// nor the haystack.
pub struct FindOverlappingIter {
    /// Raw daachorse matches that have not been yielded yet.
    inner: std::vec::IntoIter<daachorse::Match<u32>>,
}
172
173impl FindOverlappingIter {
174 fn new(automaton: &DoubleArrayAhoCorasick<u32>, haystack: &[u8]) -> Self {
175 let matches: Vec<_> = automaton.find_overlapping_iter(haystack).collect();
176 Self {
177 inner: matches.into_iter(),
178 }
179 }
180}
181
182impl Iterator for FindOverlappingIter {
183 type Item = Match;
184
185 fn next(&mut self) -> Option<Self::Item> {
186 loop {
187 let m = self.inner.next()?;
188 if m.start() % 2 == 0 {
191 return Some(Match {
192 pattern: m.value() as usize,
193 start: m.start(),
194 end: m.end(),
195 });
196 }
197 }
199 }
200}
201
/// Incrementally collects byte patterns and builds an [`Automaton`] from them.
pub struct AutomatonBuilder {
    /// Patterns added so far, in insertion order (never empty slices).
    patterns: Vec<Vec<u8>>,
}
208
209impl AutomatonBuilder {
210 pub fn new() -> Self {
212 Self {
213 patterns: Vec::new(),
214 }
215 }
216
217 pub fn add_pattern(&mut self, pattern: &[u8]) {
221 if !pattern.is_empty() {
222 self.patterns.push(pattern.to_vec());
223 }
224 }
225
226 pub fn build(self) -> Automaton {
231 use std::collections::HashSet;
232
233 if self.patterns.is_empty() {
234 return Automaton::empty();
235 }
236
237 let mut seen: HashSet<Vec<u8>> = HashSet::new();
239 let mut unique_patterns: Vec<&[u8]> = Vec::new();
240 for pattern in &self.patterns {
241 if seen.insert(pattern.clone()) {
242 unique_patterns.push(pattern.as_slice());
243 }
244 }
245
246 if unique_patterns.is_empty() {
247 return Automaton::empty();
248 }
249
250 match DoubleArrayAhoCorasick::new(unique_patterns) {
251 Ok(ac) => Automaton { inner: ac },
252 Err(_) => Automaton::empty(),
253 }
254 }
255}
256
257impl Default for AutomatonBuilder {
258 fn default() -> Self {
259 Self::new()
260 }
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects every match the automaton reports for `haystack`.
    fn scan(ac: &Automaton, haystack: &[u8]) -> Vec<Match> {
        ac.find_overlapping_iter(haystack).collect()
    }

    #[test]
    fn test_empty_automaton() {
        let found = scan(&Automaton::empty(), b"hello");
        assert!(found.is_empty());
    }

    #[test]
    fn test_build_with_patterns() {
        let patterns: Vec<&[u8]> = vec![b"hello", b"world"];
        let found = scan(&Automaton::build(&patterns), b"hello world");
        assert_eq!(found.len(), 2);
    }

    #[test]
    fn test_token_boundary_filtering() {
        let pattern: &[u8] = &[31, 49];
        let ac = Automaton::build(&[pattern]);

        // The pattern occurs at byte offset 1, an odd offset, so it is dropped.
        let haystack: &[u8] = &[109, 31, 49, 74];
        let found = scan(&ac, haystack);
        assert!(
            found.is_empty(),
            "Should not match across token boundaries"
        );
    }

    #[test]
    fn test_valid_token_match() {
        let pattern: &[u8] = &[31, 49];
        let ac = Automaton::build(&[pattern]);

        // The pattern occurs at byte offset 2, an even offset, so it is kept.
        let haystack: &[u8] = &[0, 0, 31, 49, 0, 0];
        let found = scan(&ac, haystack);
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].start, 2);
        assert_eq!(found[0].end, 4);
    }

    #[test]
    fn test_builder() {
        let mut builder = AutomatonBuilder::new();
        for word in [&b"hello"[..], &b"world"[..]] {
            builder.add_pattern(word);
        }
        let ac = builder.build();

        assert_eq!(scan(&ac, b"hello world").len(), 2);
    }

    #[test]
    fn test_builder_empty_patterns() {
        let ac = AutomatonBuilder::new().build();
        assert!(scan(&ac, b"hello").is_empty());
    }

    #[test]
    fn test_builder_skips_empty_patterns() {
        let mut builder = AutomatonBuilder::new();
        for word in [&b""[..], &b"hello"[..], &b""[..]] {
            builder.add_pattern(word);
        }
        let ac = builder.build();

        assert_eq!(scan(&ac, b"hello").len(), 1);
    }

    #[test]
    fn test_serialize_deserialize() {
        let patterns: Vec<&[u8]> = vec![b"hello", b"world", b"test"];
        let original = Automaton::build(&patterns);

        let encoded = original.serialize_bytes();
        let restored = Automaton::deserialize_unchecked(&encoded);

        let haystack = b"hello world test";
        // Equal vectors mean equal lengths and equal pattern/start/end per match.
        assert_eq!(scan(&original, haystack), scan(&restored, haystack));
    }

    #[test]
    fn test_overlapping_matches() {
        let patterns: Vec<&[u8]> = vec![b"ab", b"bc", b"abc"];
        let ac = Automaton::build(&patterns);

        let found = scan(&ac, b"abc");
        assert!(found.len() >= 2);
    }
}