Skip to main content

provenant/license_detection/
automaton.rs

//! Aho-Corasick automaton wrapper using daachorse.
//!
//! This module provides a `DoubleArrayAhoCorasick`-based automaton that is
//! significantly smaller than the aho-corasick crate's implementation.
//! The daachorse library yields an ~85% smaller binary and provides
//! built-in serialization support.
7
8use daachorse::DoubleArrayAhoCorasick;
9
/// A single pattern occurrence reported by [`Automaton::find_overlapping_iter`].
///
/// Positions are byte offsets into the searched haystack and describe the
/// half-open range `start..end`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Match {
    /// Identifier of the matched pattern, as assigned when the automaton was built.
    pub pattern: usize,
    /// Byte offset of the first matched byte (inclusive).
    pub start: usize,
    /// Byte offset one past the last matched byte (exclusive).
    pub end: usize,
}
20
21/// Aho-Corasick automaton using daachorse's double-array implementation.
22///
23/// This wrapper provides the same interface as the previous FrozenNfa
24/// but with significantly smaller memory footprint and serialization support.
25pub struct Automaton {
26    inner: DoubleArrayAhoCorasick<u32>,
27}
28
29impl std::fmt::Debug for Automaton {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        f.debug_struct("Automaton")
32            .field("num_states", &self.inner.num_states())
33            .field("heap_bytes", &self.inner.heap_bytes())
34            .finish()
35    }
36}
37
38impl Clone for Automaton {
39    fn clone(&self) -> Self {
40        let bytes = self.inner.serialize();
41        Self::deserialize_unchecked(&bytes)
42    }
43}
44
45impl Automaton {
46    /// Create a new empty automaton.
47    ///
48    /// Since daachorse requires at least one non-empty pattern, we use a
49    /// dummy pattern that will never match in practice (a unique byte sequence).
50    pub fn empty() -> Self {
51        // Use a very unlikely byte sequence as a sentinel pattern
52        // This will match but never in our token-encoded data
53        let dummy_pattern: &[u8] = &[0xFF, 0xFE, 0xFD, 0xFC, 0xFB, 0xFA, 0xF9, 0xF8];
54        match DoubleArrayAhoCorasick::new([dummy_pattern]) {
55            Ok(ac) => Self { inner: ac },
56            Err(_) => panic!("Failed to create empty automaton"),
57        }
58    }
59
60    /// Build an automaton from patterns.
61    ///
62    /// Each pattern is a byte slice. Patterns are assigned IDs in order.
63    #[allow(dead_code)]
64    pub fn build(patterns: &[&[u8]]) -> Self {
65        if patterns.is_empty() {
66            return Self::empty();
67        }
68        // Filter out empty patterns - daachorse doesn't support them
69        let non_empty: Vec<&[u8]> = patterns.iter().copied().filter(|p| !p.is_empty()).collect();
70        if non_empty.is_empty() {
71            return Self::empty();
72        }
73        match DoubleArrayAhoCorasick::new(non_empty) {
74            Ok(ac) => Self { inner: ac },
75            Err(_) => Self::empty(),
76        }
77    }
78
79    /// Find all overlapping matches in the haystack.
80    ///
81    /// Returns an iterator that yields all matches found in the haystack,
82    /// including overlapping matches. The matches are yielded in order of
83    /// their end position.
84    ///
85    /// **Important**: This filters matches to only those starting at even
86    /// byte positions (token boundaries). Each token is encoded as 2 bytes,
87    /// so matches starting at odd byte positions would span token boundaries.
88    pub fn find_overlapping_iter(&self, haystack: &[u8]) -> FindOverlappingIter {
89        FindOverlappingIter::new(&self.inner, haystack)
90    }
91
92    /// Deserialize an automaton from bytes.
93    ///
94    /// # Safety
95    /// The bytes must be valid serialized data from the underlying daachorse automaton.
96    pub fn deserialize_unchecked(bytes: &[u8]) -> Self {
97        let (ac, _) = unsafe { DoubleArrayAhoCorasick::deserialize_unchecked(bytes) };
98        Self { inner: ac }
99    }
100
101    /// Get the number of states in the automaton.
102    #[allow(dead_code)]
103    pub fn num_states(&self) -> usize {
104        self.inner.num_states()
105    }
106
107    /// Get the memory usage in bytes.
108    #[allow(dead_code)]
109    pub fn heap_bytes(&self) -> usize {
110        self.inner.heap_bytes()
111    }
112}
113
114impl Default for Automaton {
115    fn default() -> Self {
116        Self::empty()
117    }
118}
119
120/// Iterator over all overlapping matches in a haystack.
121///
122/// This iterator finds all matches, including those that overlap, by
123/// continuing to search after each match rather than skipping past it.
124///
125/// **Token Boundary Filtering**: This iterator only yields matches that
126/// start at even byte positions. Since each token is encoded as 2 bytes,
127/// matches at odd positions would incorrectly span token boundaries.
128pub struct FindOverlappingIter {
129    inner: std::vec::IntoIter<daachorse::Match<u32>>,
130}
131
132impl FindOverlappingIter {
133    fn new(automaton: &DoubleArrayAhoCorasick<u32>, haystack: &[u8]) -> Self {
134        let matches: Vec<_> = automaton.find_overlapping_iter(haystack).collect();
135        Self {
136            inner: matches.into_iter(),
137        }
138    }
139}
140
141impl Iterator for FindOverlappingIter {
142    type Item = Match;
143
144    fn next(&mut self) -> Option<Self::Item> {
145        loop {
146            let m = self.inner.next()?;
147            // Token boundary check: each token is 2 bytes, so matches must
148            // start at even byte positions. Odd positions would span tokens.
149            if m.start() % 2 == 0 {
150                return Some(Match {
151                    pattern: m.value() as usize,
152                    start: m.start(),
153                    end: m.end(),
154                });
155            }
156            // Skip matches at odd byte positions (invalid token boundaries)
157        }
158    }
159}
160
/// Incremental builder that accumulates patterns before producing an [`Automaton`].
///
/// Its interface mirrors the former `FrozenNfaBuilder`, so existing call sites
/// keep working.
pub struct AutomatonBuilder {
    patterns: Vec<Vec<u8>>,
}
167
168impl AutomatonBuilder {
169    /// Create a new builder.
170    pub fn new() -> Self {
171        Self {
172            patterns: Vec::new(),
173        }
174    }
175
176    /// Add a pattern to the automaton.
177    ///
178    /// Empty patterns are skipped.
179    pub fn add_pattern(&mut self, pattern: &[u8]) {
180        if !pattern.is_empty() {
181            self.patterns.push(pattern.to_vec());
182        }
183    }
184
185    /// Build the automaton.
186    ///
187    /// Deduplicates patterns and assigns sequential IDs (0, 1, 2, ...).
188    /// The caller must maintain their own mapping from pattern_id to rule IDs.
189    pub fn build(self) -> Automaton {
190        use std::collections::HashSet;
191
192        if self.patterns.is_empty() {
193            return Automaton::empty();
194        }
195
196        // Deduplicate patterns - daachorse rejects duplicates
197        let mut seen: HashSet<Vec<u8>> = HashSet::new();
198        let mut unique_patterns: Vec<&[u8]> = Vec::new();
199        for pattern in &self.patterns {
200            if seen.insert(pattern.clone()) {
201                unique_patterns.push(pattern.as_slice());
202            }
203        }
204
205        if unique_patterns.is_empty() {
206            return Automaton::empty();
207        }
208
209        match DoubleArrayAhoCorasick::new(unique_patterns) {
210            Ok(ac) => Automaton { inner: ac },
211            Err(_) => Automaton::empty(),
212        }
213    }
214}
215
216impl Default for AutomatonBuilder {
217    fn default() -> Self {
218        Self::new()
219    }
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_automaton() {
        let ac = Automaton::empty();
        let matches: Vec<_> = ac.find_overlapping_iter(b"hello").collect();
        assert!(matches.is_empty());
    }

    #[test]
    fn test_build_with_patterns() {
        let patterns: Vec<&[u8]> = vec![b"hello", b"world"];
        let ac = Automaton::build(&patterns);
        let matches: Vec<_> = ac.find_overlapping_iter(b"hello world").collect();
        assert_eq!(matches.len(), 2);
    }

    #[test]
    fn test_token_boundary_filtering() {
        // Pattern: [31, 49] (token 12575 in little-endian)
        let pattern: &[u8] = &[31, 49];
        let ac = Automaton::build(&[pattern]);

        // Haystack: [109, 31, 49, 74] = tokens [8045, 18993].
        // The pattern [31, 49] appears at bytes 1-2 (odd offset), which
        // straddles a token boundary, so it must NOT match.
        let haystack: &[u8] = &[109, 31, 49, 74];
        let matches: Vec<_> = ac.find_overlapping_iter(haystack).collect();
        assert!(
            matches.is_empty(),
            "Should not match across token boundaries"
        );
    }

    #[test]
    fn test_valid_token_match() {
        let pattern: &[u8] = &[31, 49];
        let ac = Automaton::build(&[pattern]);

        // Haystack with the pattern at an even offset (valid token boundary).
        let haystack: &[u8] = &[0, 0, 31, 49, 0, 0];
        let matches: Vec<_> = ac.find_overlapping_iter(haystack).collect();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].start, 2);
        assert_eq!(matches[0].end, 4);
    }

    #[test]
    fn test_builder() {
        let mut builder = AutomatonBuilder::new();
        builder.add_pattern(b"hello");
        builder.add_pattern(b"world");
        let ac = builder.build();

        let matches: Vec<_> = ac.find_overlapping_iter(b"hello world").collect();
        assert_eq!(matches.len(), 2);
    }

    #[test]
    fn test_builder_empty_patterns() {
        let builder = AutomatonBuilder::new();
        let ac = builder.build();
        let matches: Vec<_> = ac.find_overlapping_iter(b"hello").collect();
        assert!(matches.is_empty());
    }

    #[test]
    fn test_builder_skips_empty_patterns() {
        let mut builder = AutomatonBuilder::new();
        builder.add_pattern(b"");
        builder.add_pattern(b"hello");
        builder.add_pattern(b"");
        let ac = builder.build();

        let matches: Vec<_> = ac.find_overlapping_iter(b"hello").collect();
        assert_eq!(matches.len(), 1);
    }

    #[test]
    fn test_builder_deduplicates_patterns() {
        // The repeated "ab" is dropped, so the unique patterns get IDs ab=0, cd=1.
        let mut builder = AutomatonBuilder::new();
        builder.add_pattern(b"ab");
        builder.add_pattern(b"ab");
        builder.add_pattern(b"cd");
        let ac = builder.build();

        let matches: Vec<_> = ac.find_overlapping_iter(b"abcd").collect();
        assert_eq!(
            matches,
            vec![
                Match { pattern: 0, start: 0, end: 2 },
                Match { pattern: 1, start: 2, end: 4 },
            ]
        );
    }

    #[test]
    fn test_serialize_deserialize() {
        let patterns: Vec<&[u8]> = vec![b"hello", b"world", b"test"];
        let ac1 = Automaton::build(&patterns);

        let serialized = ac1.inner.serialize();
        let ac2 = Automaton::deserialize_unchecked(&serialized);

        let haystack = b"hello world test";
        let matches1: Vec<_> = ac1.find_overlapping_iter(haystack).collect();
        let matches2: Vec<_> = ac2.find_overlapping_iter(haystack).collect();

        assert_eq!(matches1.len(), 3);
        assert_eq!(matches1, matches2);
    }

    #[test]
    fn test_clone_preserves_matches() {
        let patterns: Vec<&[u8]> = vec![b"he", b"llo"];
        let ac = Automaton::build(&patterns);
        let cloned = ac.clone();

        let haystack = b"hello";
        let original: Vec<_> = ac.find_overlapping_iter(haystack).collect();
        let copy: Vec<_> = cloned.find_overlapping_iter(haystack).collect();

        assert_eq!(original.len(), 2);
        assert_eq!(original, copy);
    }

    #[test]
    fn test_overlapping_matches() {
        let patterns: Vec<&[u8]> = vec![b"ab", b"bc", b"abc"];
        let ac = Automaton::build(&patterns);

        let matches: Vec<_> = ac.find_overlapping_iter(b"abc").collect();
        // The raw automaton finds "ab" (0..2), "abc" (0..3) and "bc" (1..3).
        // "bc" starts at an odd offset and is dropped by the token-boundary
        // filter, leaving the two overlapping matches that start at 0.
        assert_eq!(
            matches,
            vec![
                Match { pattern: 0, start: 0, end: 2 },
                Match { pattern: 2, start: 0, end: 3 },
            ]
        );
    }
}