Skip to main content

provenant/license_detection/
automaton.rs

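//! Aho-Corasick automaton wrapper using daachorse.
//!
//! This module provides a `DoubleArrayAhoCorasick`-based automaton that is
//! significantly smaller than the aho-corasick crate's implementation.
//! The daachorse library provides ~85% smaller binary size and built-in
//! serialization support.
//!
//! Haystacks are treated as token streams encoded at two bytes per token, so
//! matches are only reported when they start on an even byte offset (a token
//! boundary).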
7
8use daachorse::DoubleArrayAhoCorasick;
9use rancor::Fallible;
10use rkyv::with::{ArchiveWith, DeserializeWith, SerializeWith};
11use rkyv::{Archive, Deserialize, Place, Serialize};
12
13/// rkyv `with` adapter that archives an `Automaton` as its serialized byte form.
14pub struct AsBytes;
15
16impl ArchiveWith<Automaton> for AsBytes {
17    type Archived = <Vec<u8> as Archive>::Archived;
18    type Resolver = <Vec<u8> as Archive>::Resolver;
19
20    fn resolve_with(field: &Automaton, resolver: Self::Resolver, out: Place<Self::Archived>) {
21        field.serialize_bytes().resolve(resolver, out);
22    }
23}
24
25impl<S: Fallible + rkyv::ser::Writer + rkyv::ser::Allocator + ?Sized> SerializeWith<Automaton, S>
26    for AsBytes
27{
28    fn serialize_with(field: &Automaton, serializer: &mut S) -> Result<Self::Resolver, S::Error> {
29        field.serialize_bytes().serialize(serializer)
30    }
31}
32
33impl<D: Fallible + ?Sized> DeserializeWith<<Vec<u8> as Archive>::Archived, Automaton, D> for AsBytes
34where
35    <Vec<u8> as Archive>::Archived: Deserialize<Vec<u8>, D>,
36{
37    fn deserialize_with(
38        field: &<Vec<u8> as Archive>::Archived,
39        deserializer: &mut D,
40    ) -> Result<Automaton, D::Error> {
41        let bytes: Vec<u8> = field.deserialize(deserializer)?;
42        Ok(Automaton::deserialize_unchecked(&bytes))
43    }
44}
45
/// One pattern occurrence reported by [`Automaton::find_overlapping_iter`].
///
/// `start..end` is the half-open byte range of the occurrence within the haystack.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Match {
    /// ID the pattern was given when the automaton was built.
    pub pattern: usize,
    /// Byte offset in the haystack where the match begins (inclusive).
    pub start: usize,
    /// Byte offset in the haystack where the match ends (exclusive).
    pub end: usize,
}
56
57/// Aho-Corasick automaton using daachorse's double-array implementation.
58///
59/// This wrapper provides the same interface as the previous FrozenNfa
60/// but with significantly smaller memory footprint and serialization support.
61pub struct Automaton {
62    inner: DoubleArrayAhoCorasick<u32>,
63}
64
65impl std::fmt::Debug for Automaton {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        f.debug_struct("Automaton")
68            .field("num_states", &self.inner.num_states())
69            .field("heap_bytes", &self.inner.heap_bytes())
70            .finish()
71    }
72}
73
74impl Clone for Automaton {
75    fn clone(&self) -> Self {
76        let bytes = self.inner.serialize();
77        Self::deserialize_unchecked(&bytes)
78    }
79}
80
81impl Automaton {
82    /// Create a new empty automaton.
83    ///
84    /// Since daachorse requires at least one non-empty pattern, we use a
85    /// dummy pattern that will never match in practice (a unique byte sequence).
86    pub fn empty() -> Self {
87        // Use a very unlikely byte sequence as a sentinel pattern
88        // This will match but never in our token-encoded data
89        let dummy_pattern: &[u8] = &[0xFF, 0xFE, 0xFD, 0xFC, 0xFB, 0xFA, 0xF9, 0xF8];
90        match DoubleArrayAhoCorasick::new([dummy_pattern]) {
91            Ok(ac) => Self { inner: ac },
92            Err(_) => panic!("Failed to create empty automaton"),
93        }
94    }
95
96    /// Build an automaton from patterns.
97    ///
98    /// Each pattern is a byte slice. Patterns are assigned IDs in order.
99    #[allow(dead_code)]
100    pub fn build(patterns: &[&[u8]]) -> Self {
101        if patterns.is_empty() {
102            return Self::empty();
103        }
104        // Filter out empty patterns - daachorse doesn't support them
105        let non_empty: Vec<&[u8]> = patterns.iter().copied().filter(|p| !p.is_empty()).collect();
106        if non_empty.is_empty() {
107            return Self::empty();
108        }
109        match DoubleArrayAhoCorasick::new(non_empty) {
110            Ok(ac) => Self { inner: ac },
111            Err(_) => Self::empty(),
112        }
113    }
114
115    /// Find all overlapping matches in the haystack.
116    ///
117    /// Returns an iterator that yields all matches found in the haystack,
118    /// including overlapping matches. The matches are yielded in order of
119    /// their end position.
120    ///
121    /// **Important**: This filters matches to only those starting at even
122    /// byte positions (token boundaries). Each token is encoded as 2 bytes,
123    /// so matches starting at odd byte positions would span token boundaries.
124    pub fn find_overlapping_iter(&self, haystack: &[u8]) -> FindOverlappingIter {
125        FindOverlappingIter::new(&self.inner, haystack)
126    }
127
128    /// Deserialize an automaton from bytes.
129    ///
130    /// # Safety
131    /// The bytes must be valid serialized data from the underlying daachorse automaton.
132    pub fn deserialize_unchecked(bytes: &[u8]) -> Self {
133        let (ac, _) = unsafe { DoubleArrayAhoCorasick::deserialize_unchecked(bytes) };
134        Self { inner: ac }
135    }
136
137    /// Get the number of states in the automaton.
138    #[allow(dead_code)]
139    pub fn num_states(&self) -> usize {
140        self.inner.num_states()
141    }
142
143    /// Get the memory usage in bytes.
144    #[allow(dead_code)]
145    pub fn heap_bytes(&self) -> usize {
146        self.inner.heap_bytes()
147    }
148
149    /// Serialize the automaton to a byte vector.
150    pub fn serialize_bytes(&self) -> Vec<u8> {
151        self.inner.serialize()
152    }
153}
154
155impl Default for Automaton {
156    fn default() -> Self {
157        Self::empty()
158    }
159}
160
161/// Iterator over all overlapping matches in a haystack.
162///
163/// This iterator finds all matches, including those that overlap, by
164/// continuing to search after each match rather than skipping past it.
165///
166/// **Token Boundary Filtering**: This iterator only yields matches that
167/// start at even byte positions. Since each token is encoded as 2 bytes,
168/// matches at odd positions would incorrectly span token boundaries.
169pub struct FindOverlappingIter {
170    inner: std::vec::IntoIter<daachorse::Match<u32>>,
171}
172
173impl FindOverlappingIter {
174    fn new(automaton: &DoubleArrayAhoCorasick<u32>, haystack: &[u8]) -> Self {
175        let matches: Vec<_> = automaton.find_overlapping_iter(haystack).collect();
176        Self {
177            inner: matches.into_iter(),
178        }
179    }
180}
181
182impl Iterator for FindOverlappingIter {
183    type Item = Match;
184
185    fn next(&mut self) -> Option<Self::Item> {
186        loop {
187            let m = self.inner.next()?;
188            // Token boundary check: each token is 2 bytes, so matches must
189            // start at even byte positions. Odd positions would span tokens.
190            if m.start() % 2 == 0 {
191                return Some(Match {
192                    pattern: m.value() as usize,
193                    start: m.start(),
194                    end: m.end(),
195                });
196            }
197            // Skip matches at odd byte positions (invalid token boundaries)
198        }
199    }
200}
201
/// Incremental builder that collects byte patterns and compiles them into an
/// [`Automaton`].
///
/// Keeps the same shape of API as the `FrozenNfaBuilder` it replaced: add patterns
/// one at a time, then call `build`.
pub struct AutomatonBuilder {
    /// Non-empty patterns in insertion order. Duplicates are kept until `build`.
    patterns: Vec<Vec<u8>>,
}
208
209impl AutomatonBuilder {
210    /// Create a new builder.
211    pub fn new() -> Self {
212        Self {
213            patterns: Vec::new(),
214        }
215    }
216
217    /// Add a pattern to the automaton.
218    ///
219    /// Empty patterns are skipped.
220    pub fn add_pattern(&mut self, pattern: &[u8]) {
221        if !pattern.is_empty() {
222            self.patterns.push(pattern.to_vec());
223        }
224    }
225
226    /// Build the automaton.
227    ///
228    /// Deduplicates patterns and assigns sequential IDs (0, 1, 2, ...).
229    /// The caller must maintain their own mapping from pattern_id to rule IDs.
230    pub fn build(self) -> Automaton {
231        use std::collections::HashSet;
232
233        if self.patterns.is_empty() {
234            return Automaton::empty();
235        }
236
237        // Deduplicate patterns - daachorse rejects duplicates
238        let mut seen: HashSet<Vec<u8>> = HashSet::new();
239        let mut unique_patterns: Vec<&[u8]> = Vec::new();
240        for pattern in &self.patterns {
241            if seen.insert(pattern.clone()) {
242                unique_patterns.push(pattern.as_slice());
243            }
244        }
245
246        if unique_patterns.is_empty() {
247            return Automaton::empty();
248        }
249
250        match DoubleArrayAhoCorasick::new(unique_patterns) {
251            Ok(ac) => Automaton { inner: ac },
252            Err(_) => Automaton::empty(),
253        }
254    }
255}
256
257impl Default for AutomatonBuilder {
258    fn default() -> Self {
259        Self::new()
260    }
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_automaton() {
        let ac = Automaton::empty();
        let matches: Vec<_> = ac.find_overlapping_iter(b"hello").collect();
        assert!(matches.is_empty());
    }

    #[test]
    fn test_build_with_patterns() {
        let patterns: Vec<&[u8]> = vec![b"hello", b"world"];
        let ac = Automaton::build(&patterns);
        let matches: Vec<_> = ac.find_overlapping_iter(b"hello world").collect();
        assert_eq!(matches.len(), 2);
    }

    #[test]
    fn test_token_boundary_filtering() {
        // Pattern: [31, 49] (token 12575 in little-endian)
        let pattern: &[u8] = &[31, 49];
        let ac = Automaton::build(&[pattern]);

        // Haystack: [109, 31, 49, 74] = tokens [8045, 18993]
        // The pattern [31, 49] appears at bytes 1-2 (odd position)
        // which would span token boundaries - should NOT match
        let haystack: &[u8] = &[109, 31, 49, 74];
        let matches: Vec<_> = ac.find_overlapping_iter(haystack).collect();
        assert!(
            matches.is_empty(),
            "Should not match across token boundaries"
        );
    }

    #[test]
    fn test_valid_token_match() {
        let pattern: &[u8] = &[31, 49];
        let ac = Automaton::build(&[pattern]);

        // Haystack with pattern at even position (valid token boundary)
        let haystack: &[u8] = &[0, 0, 31, 49, 0, 0];
        let matches: Vec<_> = ac.find_overlapping_iter(haystack).collect();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].start, 2);
        assert_eq!(matches[0].end, 4);
    }

    #[test]
    fn test_builder() {
        let mut builder = AutomatonBuilder::new();
        builder.add_pattern(b"hello");
        builder.add_pattern(b"world");
        let ac = builder.build();

        let matches: Vec<_> = ac.find_overlapping_iter(b"hello world").collect();
        assert_eq!(matches.len(), 2);
    }

    #[test]
    fn test_builder_empty_patterns() {
        let builder = AutomatonBuilder::new();
        let ac = builder.build();
        let matches: Vec<_> = ac.find_overlapping_iter(b"hello").collect();
        assert!(matches.is_empty());
    }

    #[test]
    fn test_builder_skips_empty_patterns() {
        let mut builder = AutomatonBuilder::new();
        builder.add_pattern(b"");
        builder.add_pattern(b"hello");
        builder.add_pattern(b"");
        let ac = builder.build();

        let matches: Vec<_> = ac.find_overlapping_iter(b"hello").collect();
        assert_eq!(matches.len(), 1);
    }

    #[test]
    fn test_serialize_deserialize() {
        let patterns: Vec<&[u8]> = vec![b"hello", b"world", b"test"];
        let ac1 = Automaton::build(&patterns);

        let serialized = ac1.inner.serialize();
        let ac2 = Automaton::deserialize_unchecked(&serialized);

        let haystack = b"hello world test";
        let matches1: Vec<_> = ac1.find_overlapping_iter(haystack).collect();
        let matches2: Vec<_> = ac2.find_overlapping_iter(haystack).collect();

        assert_eq!(matches1.len(), matches2.len());
        for (m1, m2) in matches1.iter().zip(matches2.iter()) {
            assert_eq!(m1.pattern, m2.pattern);
            assert_eq!(m1.start, m2.start);
            assert_eq!(m1.end, m2.end);
        }
    }

    #[test]
    fn test_clone_preserves_matches() {
        let patterns: Vec<&[u8]> = vec![b"hello", b"world"];
        let original = Automaton::build(&patterns);
        let cloned = original.clone();

        let haystack = b"hello world";
        let from_original: Vec<_> = original.find_overlapping_iter(haystack).collect();
        let from_clone: Vec<_> = cloned.find_overlapping_iter(haystack).collect();
        assert_eq!(from_original, from_clone);
    }

    #[test]
    fn test_overlapping_matches() {
        let patterns: Vec<&[u8]> = vec![b"ab", b"bc", b"abc"];
        let ac = Automaton::build(&patterns);

        let matches: Vec<_> = ac.find_overlapping_iter(b"abc").collect();
        // "ab" (0..2) and "abc" (0..3) overlap and both start at byte 0. "bc" also
        // occurs (1..3) but starts at an odd byte offset, so the token-boundary
        // filter drops it.
        assert_eq!(
            matches,
            vec![
                Match { pattern: 0, start: 0, end: 2 },
                Match { pattern: 2, start: 0, end: 3 },
            ]
        );
    }
}