Skip to main content

provenant/license_detection/
automaton.rs

1// SPDX-FileCopyrightText: Provenant contributors
2// SPDX-License-Identifier: Apache-2.0
3
4//! Aho-Corasick automaton wrapper using daachorse.
5//!
6//! This module provides a `DoubleArrayAhoCorasick`-based automaton that is
7//! significantly smaller than the aho-corasick crate's implementation.
8//! The daachorse library provides ~85% smaller binary size and built-in
9//! serialization support.
10
11use crate::license_detection::models::RuleId;
12use daachorse::DoubleArrayAhoCorasick;
13use rancor::Fallible;
14use rkyv::with::{ArchiveWith, DeserializeWith, SerializeWith};
15use rkyv::{Archive, Deserialize, Place, Serialize};
16
/// rkyv `with` adapter that archives an `Automaton` as its serialized byte form.
///
/// This is a zero-sized marker type: it carries no data and only selects the
/// `ArchiveWith` / `SerializeWith` / `DeserializeWith` implementations below.
/// The common marker traits are derived so it can be freely copied and shows
/// up in `Debug` output of types that mention it.
#[derive(Debug, Clone, Copy)]
pub struct AsBytes;
19
20impl ArchiveWith<Automaton> for AsBytes {
21    type Archived = <Vec<u8> as Archive>::Archived;
22    type Resolver = <Vec<u8> as Archive>::Resolver;
23
24    fn resolve_with(field: &Automaton, resolver: Self::Resolver, out: Place<Self::Archived>) {
25        field.serialize_bytes().resolve(resolver, out);
26    }
27}
28
29impl<S: Fallible + rkyv::ser::Writer + rkyv::ser::Allocator + ?Sized> SerializeWith<Automaton, S>
30    for AsBytes
31{
32    fn serialize_with(field: &Automaton, serializer: &mut S) -> Result<Self::Resolver, S::Error> {
33        field.serialize_bytes().serialize(serializer)
34    }
35}
36
37impl<D: Fallible + ?Sized> DeserializeWith<<Vec<u8> as Archive>::Archived, Automaton, D> for AsBytes
38where
39    <Vec<u8> as Archive>::Archived: Deserialize<Vec<u8>, D>,
40{
41    fn deserialize_with(
42        field: &<Vec<u8> as Archive>::Archived,
43        deserializer: &mut D,
44    ) -> Result<Automaton, D::Error> {
45        let bytes: Vec<u8> = field.deserialize(deserializer)?;
46        Ok(Automaton::deserialize_unchecked(&bytes))
47    }
48}
49
50/// A match found by the automaton.
51#[derive(Debug, Clone, PartialEq, Eq)]
52pub struct Match {
53    /// The rule ID associated with the matched pattern.
54    pub rule_id: RuleId,
55    /// Start position in haystack (bytes, inclusive).
56    pub start: usize,
57    /// End position in haystack (bytes, exclusive).
58    pub end: usize,
59}
60
61/// Aho-Corasick automaton using daachorse's double-array implementation.
62///
63/// This wrapper provides the same interface as the previous FrozenNfa
64/// but with significantly smaller memory footprint and serialization support.
65pub struct Automaton {
66    inner: DoubleArrayAhoCorasick<u32>,
67}
68
69impl std::fmt::Debug for Automaton {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        f.debug_struct("Automaton")
72            .field("num_states", &self.inner.num_states())
73            .field("heap_bytes", &self.inner.heap_bytes())
74            .finish()
75    }
76}
77
78impl Clone for Automaton {
79    fn clone(&self) -> Self {
80        let bytes = self.inner.serialize();
81        Self::deserialize_unchecked(&bytes)
82    }
83}
84
85impl Automaton {
86    /// Create a new empty automaton.
87    ///
88    /// Since daachorse requires at least one non-empty pattern, we use a
89    /// dummy pattern that will never match in practice (a unique byte sequence).
90    pub fn empty() -> Self {
91        // Use a very unlikely byte sequence as a sentinel pattern
92        // This will match but never in our token-encoded data
93        let dummy_pattern: &[u8] = &[0xFF, 0xFE, 0xFD, 0xFC, 0xFB, 0xFA, 0xF9, 0xF8];
94        match DoubleArrayAhoCorasick::new([dummy_pattern]) {
95            Ok(ac) => Self { inner: ac },
96            Err(_) => panic!("Failed to create empty automaton"),
97        }
98    }
99
100    /// Find all overlapping matches in the haystack.
101    ///
102    /// Returns an iterator that yields all matches found in the haystack,
103    /// including overlapping matches. The matches are yielded in order of
104    /// their end position.
105    ///
106    /// **Important**: This filters matches to only those starting at even
107    /// byte positions (token boundaries). Each token is encoded as 2 bytes,
108    /// so matches starting at odd byte positions would span token boundaries.
109    pub fn find_overlapping_iter(&self, haystack: &[u8]) -> FindOverlappingIter {
110        FindOverlappingIter::new(&self.inner, haystack)
111    }
112
113    /// Deserialize an automaton from bytes.
114    ///
115    /// # Safety
116    /// The bytes must be valid serialized data from the underlying daachorse automaton.
117    pub fn deserialize_unchecked(bytes: &[u8]) -> Self {
118        let (ac, _) = unsafe { DoubleArrayAhoCorasick::deserialize_unchecked(bytes) };
119        Self { inner: ac }
120    }
121
122    /// Get the number of states in the automaton.
123    pub fn num_states(&self) -> usize {
124        self.inner.num_states()
125    }
126
127    /// Get the memory usage in bytes.
128    pub fn heap_bytes(&self) -> usize {
129        self.inner.heap_bytes()
130    }
131
132    /// Serialize the automaton to a byte vector.
133    pub fn serialize_bytes(&self) -> Vec<u8> {
134        self.inner.serialize()
135    }
136}
137
138impl Default for Automaton {
139    fn default() -> Self {
140        Self::empty()
141    }
142}
143
144/// Iterator over all overlapping matches in a haystack.
145///
146/// This iterator finds all matches, including those that overlap, by
147/// continuing to search after each match rather than skipping past it.
148///
149/// **Token Boundary Filtering**: This iterator only yields matches that
150/// start at even byte positions. Since each token is encoded as 2 bytes,
151/// matches at odd positions would incorrectly span token boundaries.
152pub struct FindOverlappingIter {
153    inner: std::vec::IntoIter<daachorse::Match<u32>>,
154}
155
156impl FindOverlappingIter {
157    fn new(automaton: &DoubleArrayAhoCorasick<u32>, haystack: &[u8]) -> Self {
158        let matches: Vec<_> = automaton.find_overlapping_iter(haystack).collect();
159        Self {
160            inner: matches.into_iter(),
161        }
162    }
163}
164
165impl Iterator for FindOverlappingIter {
166    type Item = Match;
167
168    fn next(&mut self) -> Option<Self::Item> {
169        loop {
170            let m = self.inner.next()?;
171            // Token boundary check: each token is 2 bytes, so matches must
172            // start at even byte positions. Odd positions would span tokens.
173            if m.start() % 2 == 0 {
174                return Some(Match {
175                    rule_id: RuleId::new(m.value() as usize),
176                    start: m.start(),
177                    end: m.end(),
178                });
179            }
180            // Skip matches at odd byte positions (invalid token boundaries)
181        }
182    }
183}
184
/// Incremental builder for [`Automaton`].
///
/// Patterns and their values are accumulated first and turned into an
/// automaton in one step. The interface mirrors `FrozenNfaBuilder` for
/// compatibility.
pub struct AutomatonBuilder {
    /// Accepted (non-empty) patterns, in insertion order.
    patterns: Vec<Vec<u8>>,
    /// Value for each pattern; `values[i]` belongs to `patterns[i]`.
    values: Vec<u32>,
}
192
193impl AutomatonBuilder {
194    /// Create a new builder.
195    pub fn new() -> Self {
196        Self {
197            patterns: Vec::new(),
198            values: Vec::new(),
199        }
200    }
201
202    /// Add a pattern to the automaton with an associated value.
203    ///
204    /// Empty patterns are skipped. Daachorse 3.0+ supports duplicate patterns;
205    /// each occurrence gets its own value.
206    pub fn add_pattern_with_value(&mut self, pattern: &[u8], value: u32) {
207        if !pattern.is_empty() {
208            self.patterns.push(pattern.to_vec());
209            self.values.push(value);
210        }
211    }
212
213    /// Add a pattern to the automaton.
214    ///
215    /// Empty patterns are skipped. Assigns sequential IDs (0, 1, 2, ...).
216    pub fn add_pattern(&mut self, pattern: &[u8]) {
217        let value = self.patterns.len() as u32;
218        self.add_pattern_with_value(pattern, value);
219    }
220
221    /// Build the automaton.
222    ///
223    /// Uses `with_values()` so each pattern's value is directly accessible
224    /// via `Match::value()`, eliminating the need for an external pattern_id-to-rid mapping.
225    pub fn build(self) -> Automaton {
226        if self.patterns.is_empty() {
227            return Automaton::empty();
228        }
229
230        let patvals: Vec<(&[u8], u32)> = self
231            .patterns
232            .iter()
233            .zip(self.values.iter())
234            .map(|(p, &v)| (p.as_slice(), v))
235            .collect();
236
237        match DoubleArrayAhoCorasick::with_values(patvals) {
238            Ok(ac) => Automaton { inner: ac },
239            Err(_) => Automaton::empty(),
240        }
241    }
242}
243
244impl Default for AutomatonBuilder {
245    fn default() -> Self {
246        Self::new()
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an automaton whose only entry is `pattern`, added via `add_pattern`.
    fn automaton_for(pattern: &[u8]) -> Automaton {
        let mut builder = AutomatonBuilder::new();
        builder.add_pattern(pattern);
        builder.build()
    }

    /// Runs the overlapping search and gathers every yielded match.
    fn scan(automaton: &Automaton, haystack: &[u8]) -> Vec<Match> {
        automaton.find_overlapping_iter(haystack).collect()
    }

    #[test]
    fn test_token_boundary_filtering() {
        let ac = automaton_for(&[31, 49]);

        // In this haystack the pattern [31, 49] sits at bytes 1-2, i.e. it
        // starts at an odd offset and straddles two tokens, so it must be
        // filtered out.
        let found = scan(&ac, &[109, 31, 49, 74]);
        assert!(
            found.is_empty(),
            "Should not match across token boundaries"
        );
    }

    #[test]
    fn test_valid_token_match() {
        let ac = automaton_for(&[31, 49]);

        // Here the pattern starts at byte 2, an even (token-aligned) offset.
        let found = scan(&ac, &[0, 0, 31, 49, 0, 0]);
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].start, 2);
        assert_eq!(found[0].end, 4);
    }

    #[test]
    fn test_builder_skips_empty_patterns() {
        let mut builder = AutomatonBuilder::new();
        for pattern in [&b""[..], &b"hello"[..], &b""[..]] {
            builder.add_pattern(pattern);
        }
        let ac = builder.build();

        let found = scan(&ac, b"hello");
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_builder_with_values() {
        let mut builder = AutomatonBuilder::new();
        builder.add_pattern_with_value(b"hello", 42);
        builder.add_pattern_with_value(b"world", 99);
        let ac = builder.build();

        let found = scan(&ac, b"hello world");
        assert_eq!(found.len(), 2);
        assert_eq!(found[0].rule_id, RuleId::new(42));
        assert_eq!(found[1].rule_id, RuleId::new(99));
    }

    #[test]
    fn test_builder_duplicate_patterns() {
        let mut builder = AutomatonBuilder::new();
        builder.add_pattern_with_value(b"hello", 10);
        builder.add_pattern_with_value(b"hello", 20);
        let ac = builder.build();

        let found = scan(&ac, b"hello");
        assert_eq!(found.len(), 2);

        // Order between the two duplicates is not asserted; compare sorted ids.
        let mut values: Vec<RuleId> = found.iter().map(|m| m.rule_id).collect();
        values.sort();
        assert_eq!(values, vec![RuleId::new(10), RuleId::new(20)]);
    }
}
323}