Skip to main content

provenant/license_detection/
automaton.rs

1//! Aho-Corasick automaton wrapper using daachorse.
2//!
3//! This module provides a `DoubleArrayAhoCorasick`-based automaton that is
4//! significantly smaller than the aho-corasick crate's implementation.
5//! The daachorse library provides ~85% smaller binary size and built-in
6//! serialization support.
7
8use daachorse::DoubleArrayAhoCorasick;
9
/// One occurrence of a pattern inside a haystack, as reported by [`Automaton`].
///
/// Positions are byte offsets into the haystack and form the half-open
/// range `start..end`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Match {
    /// Identifier of the pattern that matched (index into the pattern list
    /// the automaton was built from).
    pub pattern: usize,
    /// Byte offset of the first byte of the match (inclusive).
    pub start: usize,
    /// Byte offset one past the last byte of the match (exclusive).
    pub end: usize,
}
20
21/// Aho-Corasick automaton using daachorse's double-array implementation.
22///
23/// This wrapper provides the same interface as the previous FrozenNfa
24/// but with significantly smaller memory footprint and serialization support.
25pub struct Automaton {
26    inner: DoubleArrayAhoCorasick<u32>,
27}
28
29impl std::fmt::Debug for Automaton {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        f.debug_struct("Automaton")
32            .field("num_states", &self.inner.num_states())
33            .field("heap_bytes", &self.inner.heap_bytes())
34            .finish()
35    }
36}
37
38impl Clone for Automaton {
39    fn clone(&self) -> Self {
40        let bytes = self.inner.serialize();
41        Self::deserialize_unchecked(&bytes)
42    }
43}
44
45impl Automaton {
46    /// Create a new empty automaton.
47    ///
48    /// Since daachorse requires at least one non-empty pattern, we use a
49    /// dummy pattern that will never match in practice (a unique byte sequence).
50    pub fn empty() -> Self {
51        // Use a very unlikely byte sequence as a sentinel pattern
52        // This will match but never in our token-encoded data
53        let dummy_pattern: &[u8] = &[0xFF, 0xFE, 0xFD, 0xFC, 0xFB, 0xFA, 0xF9, 0xF8];
54        match DoubleArrayAhoCorasick::new([dummy_pattern]) {
55            Ok(ac) => Self { inner: ac },
56            Err(_) => panic!("Failed to create empty automaton"),
57        }
58    }
59
60    /// Build an automaton from patterns.
61    ///
62    /// Each pattern is a byte slice. Patterns are assigned IDs in order.
63    #[allow(dead_code)]
64    pub fn build(patterns: &[&[u8]]) -> Self {
65        if patterns.is_empty() {
66            return Self::empty();
67        }
68        // Filter out empty patterns - daachorse doesn't support them
69        let non_empty: Vec<&[u8]> = patterns.iter().copied().filter(|p| !p.is_empty()).collect();
70        if non_empty.is_empty() {
71            return Self::empty();
72        }
73        match DoubleArrayAhoCorasick::new(non_empty) {
74            Ok(ac) => Self { inner: ac },
75            Err(_) => Self::empty(),
76        }
77    }
78
79    /// Find all overlapping matches in the haystack.
80    ///
81    /// Returns an iterator that yields all matches found in the haystack,
82    /// including overlapping matches. The matches are yielded in order of
83    /// their end position.
84    ///
85    /// **Important**: This filters matches to only those starting at even
86    /// byte positions (token boundaries). Each token is encoded as 2 bytes,
87    /// so matches starting at odd byte positions would span token boundaries.
88    pub fn find_overlapping_iter(&self, haystack: &[u8]) -> FindOverlappingIter {
89        FindOverlappingIter::new(&self.inner, haystack)
90    }
91
92    /// Deserialize an automaton from bytes.
93    ///
94    /// # Safety
95    /// The bytes must be valid serialized data from the underlying daachorse automaton.
96    pub fn deserialize_unchecked(bytes: &[u8]) -> Self {
97        let (ac, _) = unsafe { DoubleArrayAhoCorasick::deserialize_unchecked(bytes) };
98        Self { inner: ac }
99    }
100
101    /// Get the number of states in the automaton.
102    #[allow(dead_code)]
103    pub fn num_states(&self) -> usize {
104        self.inner.num_states()
105    }
106
107    /// Get the memory usage in bytes.
108    #[allow(dead_code)]
109    pub fn heap_bytes(&self) -> usize {
110        self.inner.heap_bytes()
111    }
112
113    /// Serialize the automaton to a byte vector.
114    pub fn serialize_bytes(&self) -> Vec<u8> {
115        self.inner.serialize()
116    }
117}
118
119impl Default for Automaton {
120    fn default() -> Self {
121        Self::empty()
122    }
123}
124
125/// Iterator over all overlapping matches in a haystack.
126///
127/// This iterator finds all matches, including those that overlap, by
128/// continuing to search after each match rather than skipping past it.
129///
130/// **Token Boundary Filtering**: This iterator only yields matches that
131/// start at even byte positions. Since each token is encoded as 2 bytes,
132/// matches at odd positions would incorrectly span token boundaries.
133pub struct FindOverlappingIter {
134    inner: std::vec::IntoIter<daachorse::Match<u32>>,
135}
136
137impl FindOverlappingIter {
138    fn new(automaton: &DoubleArrayAhoCorasick<u32>, haystack: &[u8]) -> Self {
139        let matches: Vec<_> = automaton.find_overlapping_iter(haystack).collect();
140        Self {
141            inner: matches.into_iter(),
142        }
143    }
144}
145
146impl Iterator for FindOverlappingIter {
147    type Item = Match;
148
149    fn next(&mut self) -> Option<Self::Item> {
150        loop {
151            let m = self.inner.next()?;
152            // Token boundary check: each token is 2 bytes, so matches must
153            // start at even byte positions. Odd positions would span tokens.
154            if m.start() % 2 == 0 {
155                return Some(Match {
156                    pattern: m.value() as usize,
157                    start: m.start(),
158                    end: m.end(),
159                });
160            }
161            // Skip matches at odd byte positions (invalid token boundaries)
162        }
163    }
164}
165
/// Incremental builder for [`Automaton`].
///
/// Patterns are collected one at a time and compiled in a single step by
/// `build`. The interface mirrors `FrozenNfaBuilder` for compatibility.
pub struct AutomatonBuilder {
    /// Patterns added so far, in insertion order (never empty slices).
    patterns: Vec<Vec<u8>>,
}
172
173impl AutomatonBuilder {
174    /// Create a new builder.
175    pub fn new() -> Self {
176        Self {
177            patterns: Vec::new(),
178        }
179    }
180
181    /// Add a pattern to the automaton.
182    ///
183    /// Empty patterns are skipped.
184    pub fn add_pattern(&mut self, pattern: &[u8]) {
185        if !pattern.is_empty() {
186            self.patterns.push(pattern.to_vec());
187        }
188    }
189
190    /// Build the automaton.
191    ///
192    /// Deduplicates patterns and assigns sequential IDs (0, 1, 2, ...).
193    /// The caller must maintain their own mapping from pattern_id to rule IDs.
194    pub fn build(self) -> Automaton {
195        use std::collections::HashSet;
196
197        if self.patterns.is_empty() {
198            return Automaton::empty();
199        }
200
201        // Deduplicate patterns - daachorse rejects duplicates
202        let mut seen: HashSet<Vec<u8>> = HashSet::new();
203        let mut unique_patterns: Vec<&[u8]> = Vec::new();
204        for pattern in &self.patterns {
205            if seen.insert(pattern.clone()) {
206                unique_patterns.push(pattern.as_slice());
207            }
208        }
209
210        if unique_patterns.is_empty() {
211            return Automaton::empty();
212        }
213
214        match DoubleArrayAhoCorasick::new(unique_patterns) {
215            Ok(ac) => Automaton { inner: ac },
216            Err(_) => Automaton::empty(),
217        }
218    }
219}
220
221impl Default for AutomatonBuilder {
222    fn default() -> Self {
223        Self::new()
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// The sentinel-only automaton must not match ordinary input.
    #[test]
    fn test_empty_automaton() {
        let ac = Automaton::empty();
        let matches: Vec<_> = ac.find_overlapping_iter(b"hello").collect();
        assert!(matches.is_empty());
    }

    /// Both patterns start at even offsets (0 and 6), so both are reported.
    #[test]
    fn test_build_with_patterns() {
        let patterns: Vec<&[u8]> = vec![b"hello", b"world"];
        let ac = Automaton::build(&patterns);
        let matches: Vec<_> = ac.find_overlapping_iter(b"hello world").collect();
        assert_eq!(matches.len(), 2);
    }

    #[test]
    fn test_token_boundary_filtering() {
        // Pattern: [31, 49] (token 12575 in little-endian)
        let pattern: &[u8] = &[31, 49];
        let ac = Automaton::build(&[pattern]);

        // Haystack: [109, 31, 49, 74] = tokens [8045, 18993]
        // The pattern [31, 49] appears at bytes 1-2 (odd position)
        // which would span token boundaries - should NOT match
        let haystack: &[u8] = &[109, 31, 49, 74];
        let matches: Vec<_> = ac.find_overlapping_iter(haystack).collect();
        assert!(
            matches.is_empty(),
            "Should not match across token boundaries"
        );
    }

    #[test]
    fn test_valid_token_match() {
        let pattern: &[u8] = &[31, 49];
        let ac = Automaton::build(&[pattern]);

        // Haystack with pattern at even position (valid token boundary)
        let haystack: &[u8] = &[0, 0, 31, 49, 0, 0];
        let matches: Vec<_> = ac.find_overlapping_iter(haystack).collect();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].start, 2);
        assert_eq!(matches[0].end, 4);
    }

    #[test]
    fn test_builder() {
        let mut builder = AutomatonBuilder::new();
        builder.add_pattern(b"hello");
        builder.add_pattern(b"world");
        let ac = builder.build();

        let matches: Vec<_> = ac.find_overlapping_iter(b"hello world").collect();
        assert_eq!(matches.len(), 2);
    }

    #[test]
    fn test_builder_empty_patterns() {
        let builder = AutomatonBuilder::new();
        let ac = builder.build();
        let matches: Vec<_> = ac.find_overlapping_iter(b"hello").collect();
        assert!(matches.is_empty());
    }

    #[test]
    fn test_builder_skips_empty_patterns() {
        let mut builder = AutomatonBuilder::new();
        builder.add_pattern(b"");
        builder.add_pattern(b"hello");
        builder.add_pattern(b"");
        let ac = builder.build();

        let matches: Vec<_> = ac.find_overlapping_iter(b"hello").collect();
        assert_eq!(matches.len(), 1);
    }

    /// Adding the same pattern twice must still yield a working automaton
    /// that reports the pattern once per occurrence.
    #[test]
    fn test_builder_deduplicates_patterns() {
        let mut builder = AutomatonBuilder::new();
        builder.add_pattern(b"hello");
        builder.add_pattern(b"hello");
        let ac = builder.build();

        let matches: Vec<_> = ac.find_overlapping_iter(b"hello").collect();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].start, 0);
        assert_eq!(matches[0].end, 5);
    }

    /// A serialize/deserialize round trip must preserve matching behavior.
    #[test]
    fn test_serialize_deserialize() {
        let patterns: Vec<&[u8]> = vec![b"hello", b"world", b"test"];
        let ac1 = Automaton::build(&patterns);

        let serialized = ac1.serialize_bytes();
        let ac2 = Automaton::deserialize_unchecked(&serialized);

        let haystack = b"hello world test";
        let matches1: Vec<_> = ac1.find_overlapping_iter(haystack).collect();
        let matches2: Vec<_> = ac2.find_overlapping_iter(haystack).collect();

        assert!(!matches1.is_empty());
        assert_eq!(matches1, matches2);
    }

    #[test]
    fn test_overlapping_matches() {
        let patterns: Vec<&[u8]> = vec![b"ab", b"bc", b"abc"];
        let ac = Automaton::build(&patterns);

        let matches: Vec<_> = ac.find_overlapping_iter(b"abc").collect();
        // "ab" and "abc" both start at byte 0 and overlap; "bc" starts at
        // byte 1 (odd offset) and is dropped by the token-boundary filter.
        assert_eq!(
            matches,
            vec![
                Match {
                    pattern: 0,
                    start: 0,
                    end: 2
                },
                Match {
                    pattern: 2,
                    start: 0,
                    end: 3
                },
            ]
        );
    }
}