Skip to main content

provenant/license_detection/
automaton.rs

1// SPDX-FileCopyrightText: Provenant contributors
2// SPDX-License-Identifier: Apache-2.0
3
4//! Aho-Corasick automaton wrapper using daachorse.
5//!
6//! This module provides a `DoubleArrayAhoCorasick`-based automaton that is
7//! significantly smaller than the aho-corasick crate's implementation.
8//! The daachorse library provides ~85% smaller binary size and built-in
9//! serialization support.
10
11use crate::license_detection::models::RuleId;
12use daachorse::DoubleArrayAhoCorasick;
13use rancor::Fallible;
14use rkyv::with::{ArchiveWith, DeserializeWith, SerializeWith};
15use rkyv::{Archive, Deserialize, Place, Serialize};
16
17/// rkyv `with` adapter that archives an `Automaton` as its serialized byte form.
18pub struct AsBytes;
19
20impl ArchiveWith<Automaton> for AsBytes {
21    type Archived = <Vec<u8> as Archive>::Archived;
22    type Resolver = <Vec<u8> as Archive>::Resolver;
23
24    fn resolve_with(field: &Automaton, resolver: Self::Resolver, out: Place<Self::Archived>) {
25        field.serialize_bytes().resolve(resolver, out);
26    }
27}
28
29impl<S: Fallible + rkyv::ser::Writer + rkyv::ser::Allocator + ?Sized> SerializeWith<Automaton, S>
30    for AsBytes
31{
32    fn serialize_with(field: &Automaton, serializer: &mut S) -> Result<Self::Resolver, S::Error> {
33        field.serialize_bytes().serialize(serializer)
34    }
35}
36
37impl<D: Fallible + ?Sized> DeserializeWith<<Vec<u8> as Archive>::Archived, Automaton, D> for AsBytes
38where
39    <Vec<u8> as Archive>::Archived: Deserialize<Vec<u8>, D>,
40{
41    fn deserialize_with(
42        field: &<Vec<u8> as Archive>::Archived,
43        deserializer: &mut D,
44    ) -> Result<Automaton, D::Error> {
45        let bytes: Vec<u8> = field.deserialize(deserializer)?;
46        Ok(Automaton::deserialize_unchecked(&bytes))
47    }
48}
49
50/// A match found by the automaton.
51#[derive(Debug, Clone, PartialEq, Eq)]
52pub struct Match {
53    /// The rule ID associated with the matched pattern.
54    pub rule_id: RuleId,
55    /// Start position in haystack (bytes, inclusive).
56    pub start: usize,
57    /// End position in haystack (bytes, exclusive).
58    pub end: usize,
59}
60
61/// Aho-Corasick automaton using daachorse's double-array implementation.
62///
63/// This wrapper provides the same interface as the previous FrozenNfa
64/// but with significantly smaller memory footprint and serialization support.
65pub struct Automaton {
66    inner: DoubleArrayAhoCorasick<u32>,
67}
68
69impl std::fmt::Debug for Automaton {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        f.debug_struct("Automaton")
72            .field("num_states", &self.inner.num_states())
73            .field("heap_bytes", &self.inner.heap_bytes())
74            .finish()
75    }
76}
77
78impl Clone for Automaton {
79    fn clone(&self) -> Self {
80        let bytes = self.inner.serialize();
81        Self::deserialize_unchecked(&bytes)
82    }
83}
84
85impl Automaton {
86    /// Create a new empty automaton.
87    ///
88    /// Since daachorse requires at least one non-empty pattern, we use a
89    /// dummy pattern that will never match in practice (a unique byte sequence).
90    pub fn empty() -> Self {
91        // Use a very unlikely byte sequence as a sentinel pattern
92        // This will match but never in our token-encoded data
93        let dummy_pattern: &[u8] = &[0xFF, 0xFE, 0xFD, 0xFC, 0xFB, 0xFA, 0xF9, 0xF8];
94        let inner = DoubleArrayAhoCorasick::new([dummy_pattern])
95            .expect("Failed to create empty automaton with hardcoded dummy pattern");
96        Self { inner }
97    }
98
99    /// Find all overlapping matches in the haystack.
100    ///
101    /// Returns an iterator that yields all matches found in the haystack,
102    /// including overlapping matches. The matches are yielded in order of
103    /// their end position.
104    ///
105    /// **Important**: This filters matches to only those starting at even
106    /// byte positions (token boundaries). Each token is encoded as 2 bytes,
107    /// so matches starting at odd byte positions would span token boundaries.
108    pub fn find_overlapping_iter(&self, haystack: &[u8]) -> FindOverlappingIter {
109        FindOverlappingIter::new(&self.inner, haystack)
110    }
111
112    /// Deserialize an automaton from bytes.
113    ///
114    /// # Safety
115    /// The bytes must be valid serialized data from the underlying daachorse automaton.
116    pub fn deserialize_unchecked(bytes: &[u8]) -> Self {
117        let (ac, _) = unsafe { DoubleArrayAhoCorasick::deserialize_unchecked(bytes) };
118        Self { inner: ac }
119    }
120
121    /// Get the number of states in the automaton.
122    pub fn num_states(&self) -> usize {
123        self.inner.num_states()
124    }
125
126    /// Get the memory usage in bytes.
127    pub fn heap_bytes(&self) -> usize {
128        self.inner.heap_bytes()
129    }
130
131    /// Serialize the automaton to a byte vector.
132    pub fn serialize_bytes(&self) -> Vec<u8> {
133        self.inner.serialize()
134    }
135}
136
137impl Default for Automaton {
138    fn default() -> Self {
139        Self::empty()
140    }
141}
142
143/// Iterator over all overlapping matches in a haystack.
144///
145/// This iterator finds all matches, including those that overlap, by
146/// continuing to search after each match rather than skipping past it.
147///
148/// **Token Boundary Filtering**: This iterator only yields matches that
149/// start at even byte positions. Since each token is encoded as 2 bytes,
150/// matches at odd positions would incorrectly span token boundaries.
151pub struct FindOverlappingIter {
152    inner: std::vec::IntoIter<daachorse::Match<u32>>,
153}
154
155impl FindOverlappingIter {
156    fn new(automaton: &DoubleArrayAhoCorasick<u32>, haystack: &[u8]) -> Self {
157        let matches: Vec<_> = automaton.find_overlapping_iter(haystack).collect();
158        Self {
159            inner: matches.into_iter(),
160        }
161    }
162}
163
164impl Iterator for FindOverlappingIter {
165    type Item = Match;
166
167    fn next(&mut self) -> Option<Self::Item> {
168        loop {
169            let m = self.inner.next()?;
170            // Token boundary check: each token is 2 bytes, so matches must
171            // start at even byte positions. Odd positions would span tokens.
172            if m.start() % 2 == 0 {
173                return Some(Match {
174                    rule_id: RuleId::new(m.value() as usize),
175                    start: m.start(),
176                    end: m.end(),
177                });
178            }
179            // Skip matches at odd byte positions (invalid token boundaries)
180        }
181    }
182}
183
/// Incremental builder for an [`Automaton`].
///
/// Its interface mirrors `FrozenNfaBuilder` for compatibility.
pub struct AutomatonBuilder {
    /// Non-empty patterns, in insertion order.
    patterns: Vec<Vec<u8>>,
    /// Value for each pattern; `values[i]` belongs to `patterns[i]`.
    values: Vec<u32>,
}
191
192impl AutomatonBuilder {
193    /// Create a new builder.
194    pub fn new() -> Self {
195        Self {
196            patterns: Vec::new(),
197            values: Vec::new(),
198        }
199    }
200
201    /// Add a pattern to the automaton with an associated value.
202    ///
203    /// Empty patterns are skipped. Daachorse 3.0+ supports duplicate patterns;
204    /// each occurrence gets its own value.
205    pub fn add_pattern_with_value(&mut self, pattern: &[u8], value: u32) {
206        if !pattern.is_empty() {
207            self.patterns.push(pattern.to_vec());
208            self.values.push(value);
209        }
210    }
211
212    /// Add a pattern to the automaton.
213    ///
214    /// Empty patterns are skipped. Assigns sequential IDs (0, 1, 2, ...).
215    pub fn add_pattern(&mut self, pattern: &[u8]) {
216        let value = self.patterns.len() as u32;
217        self.add_pattern_with_value(pattern, value);
218    }
219
220    /// Build the automaton.
221    ///
222    /// Uses `with_values()` so each pattern's value is directly accessible
223    /// via `Match::value()`, eliminating the need for an external pattern_id-to-rid mapping.
224    pub fn build(self) -> Automaton {
225        if self.patterns.is_empty() {
226            return Automaton::empty();
227        }
228
229        let patvals: Vec<(&[u8], u32)> = self
230            .patterns
231            .iter()
232            .zip(self.values.iter())
233            .map(|(p, &v)| (p.as_slice(), v))
234            .collect();
235
236        match DoubleArrayAhoCorasick::with_values(patvals) {
237            Ok(ac) => Automaton { inner: ac },
238            Err(_) => Automaton::empty(),
239        }
240    }
241}
242
243impl Default for AutomatonBuilder {
244    fn default() -> Self {
245        Self::new()
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an automaton whose patterns receive auto-assigned values.
    fn automaton_of(patterns: &[&[u8]]) -> Automaton {
        let mut builder = AutomatonBuilder::new();
        for pattern in patterns {
            builder.add_pattern(pattern);
        }
        builder.build()
    }

    /// Builds an automaton from explicit `(pattern, value)` pairs.
    fn automaton_with_values(entries: &[(&[u8], u32)]) -> Automaton {
        let mut builder = AutomatonBuilder::new();
        for &(pattern, value) in entries {
            builder.add_pattern_with_value(pattern, value);
        }
        builder.build()
    }

    /// Collects every match the automaton reports for `haystack`.
    fn scan(ac: &Automaton, haystack: &[u8]) -> Vec<Match> {
        ac.find_overlapping_iter(haystack).collect()
    }

    #[test]
    fn test_token_boundary_filtering() {
        let pattern: &[u8] = &[31, 49];
        let ac = automaton_of(&[pattern]);

        // [31, 49] sits at bytes 1-2 here, i.e. it starts at an odd offset and
        // straddles two tokens, so nothing may be reported.
        let haystack: &[u8] = &[109, 31, 49, 74];
        let found = scan(&ac, haystack);
        assert!(found.is_empty(), "Should not match across token boundaries");
    }

    #[test]
    fn test_valid_token_match() {
        let pattern: &[u8] = &[31, 49];
        let ac = automaton_of(&[pattern]);

        let haystack: &[u8] = &[0, 0, 31, 49, 0, 0];
        let found = scan(&ac, haystack);
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].start, 2);
        assert_eq!(found[0].end, 4);
    }

    #[test]
    fn test_builder_skips_empty_patterns() {
        let blank: &[u8] = b"";
        let word: &[u8] = b"hello";
        let ac = automaton_of(&[blank, word, blank]);

        let found = scan(&ac, b"hello");
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_builder_with_values() {
        let hello: &[u8] = b"hello";
        let world: &[u8] = b"world";
        let ac = automaton_with_values(&[(hello, 42), (world, 99)]);

        let found = scan(&ac, b"hello world");
        assert_eq!(found.len(), 2);
        assert_eq!(found[0].rule_id, RuleId::new(42));
        assert_eq!(found[1].rule_id, RuleId::new(99));
    }

    #[test]
    fn test_builder_duplicate_patterns() {
        let hello: &[u8] = b"hello";
        let ac = automaton_with_values(&[(hello, 10), (hello, 20)]);

        let found = scan(&ac, b"hello");
        assert_eq!(found.len(), 2);
        let mut ids: Vec<RuleId> = found.iter().map(|m| m.rule_id).collect();
        ids.sort();
        assert_eq!(ids, vec![RuleId::new(10), RuleId::new(20)]);
    }
}