Skip to main content

provenant/license_detection/
automaton.rs

1// SPDX-FileCopyrightText: Provenant contributors
2// SPDX-License-Identifier: Apache-2.0
3
4//! Aho-Corasick automaton wrapper using daachorse.
5//!
6//! This module provides a `DoubleArrayAhoCorasick`-based automaton that is
7//! significantly smaller than the aho-corasick crate's implementation.
8//! The daachorse library provides ~85% smaller binary size and built-in
9//! serialization support.
10
11use daachorse::DoubleArrayAhoCorasick;
12use rancor::Fallible;
13use rkyv::with::{ArchiveWith, DeserializeWith, SerializeWith};
14use rkyv::{Archive, Deserialize, Place, Serialize};
15
16/// rkyv `with` adapter that archives an `Automaton` as its serialized byte form.
17pub struct AsBytes;
18
19impl ArchiveWith<Automaton> for AsBytes {
20    type Archived = <Vec<u8> as Archive>::Archived;
21    type Resolver = <Vec<u8> as Archive>::Resolver;
22
23    fn resolve_with(field: &Automaton, resolver: Self::Resolver, out: Place<Self::Archived>) {
24        field.serialize_bytes().resolve(resolver, out);
25    }
26}
27
28impl<S: Fallible + rkyv::ser::Writer + rkyv::ser::Allocator + ?Sized> SerializeWith<Automaton, S>
29    for AsBytes
30{
31    fn serialize_with(field: &Automaton, serializer: &mut S) -> Result<Self::Resolver, S::Error> {
32        field.serialize_bytes().serialize(serializer)
33    }
34}
35
36impl<D: Fallible + ?Sized> DeserializeWith<<Vec<u8> as Archive>::Archived, Automaton, D> for AsBytes
37where
38    <Vec<u8> as Archive>::Archived: Deserialize<Vec<u8>, D>,
39{
40    fn deserialize_with(
41        field: &<Vec<u8> as Archive>::Archived,
42        deserializer: &mut D,
43    ) -> Result<Automaton, D::Error> {
44        let bytes: Vec<u8> = field.deserialize(deserializer)?;
45        Ok(Automaton::deserialize_unchecked(&bytes))
46    }
47}
48
/// A single pattern occurrence reported by the automaton.
///
/// `start..end` is a half-open byte range into the searched haystack.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Match {
    /// Identifier of the matched pattern, as assigned when the automaton was built.
    pub pattern: usize,
    /// Byte offset in the haystack where the match begins (inclusive).
    pub start: usize,
    /// Byte offset in the haystack where the match stops (exclusive).
    pub end: usize,
}
59
60/// Aho-Corasick automaton using daachorse's double-array implementation.
61///
62/// This wrapper provides the same interface as the previous FrozenNfa
63/// but with significantly smaller memory footprint and serialization support.
64pub struct Automaton {
65    inner: DoubleArrayAhoCorasick<u32>,
66}
67
68impl std::fmt::Debug for Automaton {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        f.debug_struct("Automaton")
71            .field("num_states", &self.inner.num_states())
72            .field("heap_bytes", &self.inner.heap_bytes())
73            .finish()
74    }
75}
76
77impl Clone for Automaton {
78    fn clone(&self) -> Self {
79        let bytes = self.inner.serialize();
80        Self::deserialize_unchecked(&bytes)
81    }
82}
83
84impl Automaton {
85    /// Create a new empty automaton.
86    ///
87    /// Since daachorse requires at least one non-empty pattern, we use a
88    /// dummy pattern that will never match in practice (a unique byte sequence).
89    pub fn empty() -> Self {
90        // Use a very unlikely byte sequence as a sentinel pattern
91        // This will match but never in our token-encoded data
92        let dummy_pattern: &[u8] = &[0xFF, 0xFE, 0xFD, 0xFC, 0xFB, 0xFA, 0xF9, 0xF8];
93        match DoubleArrayAhoCorasick::new([dummy_pattern]) {
94            Ok(ac) => Self { inner: ac },
95            Err(_) => panic!("Failed to create empty automaton"),
96        }
97    }
98
99    /// Build an automaton from patterns.
100    ///
101    /// Each pattern is a byte slice. Patterns are assigned IDs in order.
102    #[allow(dead_code)]
103    pub fn build(patterns: &[&[u8]]) -> Self {
104        if patterns.is_empty() {
105            return Self::empty();
106        }
107        // Filter out empty patterns - daachorse doesn't support them
108        let non_empty: Vec<&[u8]> = patterns.iter().copied().filter(|p| !p.is_empty()).collect();
109        if non_empty.is_empty() {
110            return Self::empty();
111        }
112        match DoubleArrayAhoCorasick::new(non_empty) {
113            Ok(ac) => Self { inner: ac },
114            Err(_) => Self::empty(),
115        }
116    }
117
118    /// Find all overlapping matches in the haystack.
119    ///
120    /// Returns an iterator that yields all matches found in the haystack,
121    /// including overlapping matches. The matches are yielded in order of
122    /// their end position.
123    ///
124    /// **Important**: This filters matches to only those starting at even
125    /// byte positions (token boundaries). Each token is encoded as 2 bytes,
126    /// so matches starting at odd byte positions would span token boundaries.
127    pub fn find_overlapping_iter(&self, haystack: &[u8]) -> FindOverlappingIter {
128        FindOverlappingIter::new(&self.inner, haystack)
129    }
130
131    /// Deserialize an automaton from bytes.
132    ///
133    /// # Safety
134    /// The bytes must be valid serialized data from the underlying daachorse automaton.
135    pub fn deserialize_unchecked(bytes: &[u8]) -> Self {
136        let (ac, _) = unsafe { DoubleArrayAhoCorasick::deserialize_unchecked(bytes) };
137        Self { inner: ac }
138    }
139
140    /// Get the number of states in the automaton.
141    #[allow(dead_code)]
142    pub fn num_states(&self) -> usize {
143        self.inner.num_states()
144    }
145
146    /// Get the memory usage in bytes.
147    #[allow(dead_code)]
148    pub fn heap_bytes(&self) -> usize {
149        self.inner.heap_bytes()
150    }
151
152    /// Serialize the automaton to a byte vector.
153    pub fn serialize_bytes(&self) -> Vec<u8> {
154        self.inner.serialize()
155    }
156}
157
158impl Default for Automaton {
159    fn default() -> Self {
160        Self::empty()
161    }
162}
163
164/// Iterator over all overlapping matches in a haystack.
165///
166/// This iterator finds all matches, including those that overlap, by
167/// continuing to search after each match rather than skipping past it.
168///
169/// **Token Boundary Filtering**: This iterator only yields matches that
170/// start at even byte positions. Since each token is encoded as 2 bytes,
171/// matches at odd positions would incorrectly span token boundaries.
172pub struct FindOverlappingIter {
173    inner: std::vec::IntoIter<daachorse::Match<u32>>,
174}
175
176impl FindOverlappingIter {
177    fn new(automaton: &DoubleArrayAhoCorasick<u32>, haystack: &[u8]) -> Self {
178        let matches: Vec<_> = automaton.find_overlapping_iter(haystack).collect();
179        Self {
180            inner: matches.into_iter(),
181        }
182    }
183}
184
185impl Iterator for FindOverlappingIter {
186    type Item = Match;
187
188    fn next(&mut self) -> Option<Self::Item> {
189        loop {
190            let m = self.inner.next()?;
191            // Token boundary check: each token is 2 bytes, so matches must
192            // start at even byte positions. Odd positions would span tokens.
193            if m.start() % 2 == 0 {
194                return Some(Match {
195                    pattern: m.value() as usize,
196                    start: m.start(),
197                    end: m.end(),
198                });
199            }
200            // Skip matches at odd byte positions (invalid token boundaries)
201        }
202    }
203}
204
/// Incremental builder for an [`Automaton`].
///
/// Its interface mirrors `FrozenNfaBuilder` for compatibility.
pub struct AutomatonBuilder {
    /// Patterns added so far, in insertion order.
    patterns: Vec<Vec<u8>>,
}
211
212impl AutomatonBuilder {
213    /// Create a new builder.
214    pub fn new() -> Self {
215        Self {
216            patterns: Vec::new(),
217        }
218    }
219
220    /// Add a pattern to the automaton.
221    ///
222    /// Empty patterns are skipped.
223    pub fn add_pattern(&mut self, pattern: &[u8]) {
224        if !pattern.is_empty() {
225            self.patterns.push(pattern.to_vec());
226        }
227    }
228
229    /// Build the automaton.
230    ///
231    /// Deduplicates patterns and assigns sequential IDs (0, 1, 2, ...).
232    /// The caller must maintain their own mapping from pattern_id to rule IDs.
233    pub fn build(self) -> Automaton {
234        use std::collections::HashSet;
235
236        if self.patterns.is_empty() {
237            return Automaton::empty();
238        }
239
240        // Deduplicate patterns - daachorse rejects duplicates
241        let mut seen: HashSet<Vec<u8>> = HashSet::new();
242        let mut unique_patterns: Vec<&[u8]> = Vec::new();
243        for pattern in &self.patterns {
244            if seen.insert(pattern.clone()) {
245                unique_patterns.push(pattern.as_slice());
246            }
247        }
248
249        if unique_patterns.is_empty() {
250            return Automaton::empty();
251        }
252
253        match DoubleArrayAhoCorasick::new(unique_patterns) {
254            Ok(ac) => Automaton { inner: ac },
255            Err(_) => Automaton::empty(),
256        }
257    }
258}
259
260impl Default for AutomatonBuilder {
261    fn default() -> Self {
262        Self::new()
263    }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// The sentinel-only automaton reports nothing for ordinary text.
    #[test]
    fn test_empty_automaton() {
        let ac = Automaton::empty();
        let matches: Vec<_> = ac.find_overlapping_iter(b"hello").collect();
        assert!(matches.is_empty());
    }

    /// Both patterns start at even offsets (0 and 6), so both are reported.
    #[test]
    fn test_build_with_patterns() {
        let patterns: Vec<&[u8]> = vec![b"hello", b"world"];
        let ac = Automaton::build(&patterns);
        let matches: Vec<_> = ac.find_overlapping_iter(b"hello world").collect();
        assert_eq!(matches.len(), 2);
    }

    #[test]
    fn test_token_boundary_filtering() {
        // Pattern: [31, 49] (token 12575 in little-endian)
        let pattern: &[u8] = &[31, 49];
        let ac = Automaton::build(&[pattern]);

        // Haystack: [109, 31, 49, 74] = tokens [8045, 18993]
        // The pattern [31, 49] appears at bytes 1-2 (odd position)
        // which would span token boundaries - should NOT match
        let haystack: &[u8] = &[109, 31, 49, 74];
        let matches: Vec<_> = ac.find_overlapping_iter(haystack).collect();
        assert!(
            matches.is_empty(),
            "Should not match across token boundaries"
        );
    }

    #[test]
    fn test_valid_token_match() {
        let pattern: &[u8] = &[31, 49];
        let ac = Automaton::build(&[pattern]);

        // Haystack with pattern at even position (valid token boundary)
        let haystack: &[u8] = &[0, 0, 31, 49, 0, 0];
        let matches: Vec<_> = ac.find_overlapping_iter(haystack).collect();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].start, 2);
        assert_eq!(matches[0].end, 4);
    }

    #[test]
    fn test_builder() {
        let mut builder = AutomatonBuilder::new();
        builder.add_pattern(b"hello");
        builder.add_pattern(b"world");
        let ac = builder.build();

        let matches: Vec<_> = ac.find_overlapping_iter(b"hello world").collect();
        assert_eq!(matches.len(), 2);
    }

    #[test]
    fn test_builder_empty_patterns() {
        let builder = AutomatonBuilder::new();
        let ac = builder.build();
        let matches: Vec<_> = ac.find_overlapping_iter(b"hello").collect();
        assert!(matches.is_empty());
    }

    #[test]
    fn test_builder_skips_empty_patterns() {
        let mut builder = AutomatonBuilder::new();
        builder.add_pattern(b"");
        builder.add_pattern(b"hello");
        builder.add_pattern(b"");
        let ac = builder.build();

        let matches: Vec<_> = ac.find_overlapping_iter(b"hello").collect();
        assert_eq!(matches.len(), 1);
    }

    /// A serialize/deserialize round trip must not change the reported matches.
    #[test]
    fn test_serialize_deserialize() {
        let patterns: Vec<&[u8]> = vec![b"hello", b"world", b"test"];
        let ac1 = Automaton::build(&patterns);

        let serialized = ac1.inner.serialize();
        let ac2 = Automaton::deserialize_unchecked(&serialized);

        let haystack = b"hello world test";
        let matches1: Vec<_> = ac1.find_overlapping_iter(haystack).collect();
        let matches2: Vec<_> = ac2.find_overlapping_iter(haystack).collect();

        assert_eq!(matches1.len(), matches2.len());
        for (m1, m2) in matches1.iter().zip(matches2.iter()) {
            assert_eq!(m1.pattern, m2.pattern);
            assert_eq!(m1.start, m2.start);
            assert_eq!(m1.end, m2.end);
        }
    }

    #[test]
    fn test_overlapping_matches() {
        let patterns: Vec<&[u8]> = vec![b"ab", b"bc", b"abc"];
        let ac = Automaton::build(&patterns);

        let matches: Vec<_> = ac.find_overlapping_iter(b"abc").collect();
        // "ab" and "abc" both start at byte 0 and overlap, so both are
        // reported. "bc" starts at byte 1 (odd) and is removed by the
        // token-boundary filter.
        assert_eq!(matches.len(), 2);
        assert!(matches.iter().all(|m| m.start == 0));
    }
}
371        let matches: Vec<_> = ac.find_overlapping_iter(b"abc").collect();
372        // Should find "ab", "abc", and "bc" (all overlapping)
373        assert!(matches.len() >= 2);
374    }
375}