provenant/license_detection/
automaton.rs1use crate::license_detection::models::RuleId;
12use daachorse::DoubleArrayAhoCorasick;
13use rancor::Fallible;
14use rkyv::with::{ArchiveWith, DeserializeWith, SerializeWith};
15use rkyv::{Archive, Deserialize, Place, Serialize};
16
/// Marker type for rkyv's `with` mechanism that archives an [`Automaton`] as a
/// plain byte vector.
///
/// The `ArchiveWith`, `SerializeWith` and `DeserializeWith` implementations in
/// this file turn the automaton into the bytes returned by
/// [`Automaton::serialize_bytes`] and rebuild it with
/// [`Automaton::deserialize_unchecked`].
pub struct AsBytes;
19
20impl ArchiveWith<Automaton> for AsBytes {
21 type Archived = <Vec<u8> as Archive>::Archived;
22 type Resolver = <Vec<u8> as Archive>::Resolver;
23
24 fn resolve_with(field: &Automaton, resolver: Self::Resolver, out: Place<Self::Archived>) {
25 field.serialize_bytes().resolve(resolver, out);
26 }
27}
28
29impl<S: Fallible + rkyv::ser::Writer + rkyv::ser::Allocator + ?Sized> SerializeWith<Automaton, S>
30 for AsBytes
31{
32 fn serialize_with(field: &Automaton, serializer: &mut S) -> Result<Self::Resolver, S::Error> {
33 field.serialize_bytes().serialize(serializer)
34 }
35}
36
37impl<D: Fallible + ?Sized> DeserializeWith<<Vec<u8> as Archive>::Archived, Automaton, D> for AsBytes
38where
39 <Vec<u8> as Archive>::Archived: Deserialize<Vec<u8>, D>,
40{
41 fn deserialize_with(
42 field: &<Vec<u8> as Archive>::Archived,
43 deserializer: &mut D,
44 ) -> Result<Automaton, D::Error> {
45 let bytes: Vec<u8> = field.deserialize(deserializer)?;
46 Ok(Automaton::deserialize_unchecked(&bytes))
47 }
48}
49
/// A single pattern occurrence reported by [`FindOverlappingIter`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Match {
    /// Rule identifier built from the value stored with the matched pattern.
    pub rule_id: RuleId,
    /// Byte offset in the haystack where the match starts.
    pub start: usize,
    /// Byte offset in the haystack where the match ends (exclusive: a 2-byte
    /// pattern found at offset 2 is reported with `end == 4`).
    pub end: usize,
}
60
/// Multi-pattern matcher backed by a `daachorse` double-array Aho-Corasick
/// automaton whose pattern values are `u32`.
pub struct Automaton {
    /// The wrapped automaton; each pattern's `u32` value becomes the
    /// `rule_id` of the matches it produces.
    inner: DoubleArrayAhoCorasick<u32>,
}
68
69impl std::fmt::Debug for Automaton {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 f.debug_struct("Automaton")
72 .field("num_states", &self.inner.num_states())
73 .field("heap_bytes", &self.inner.heap_bytes())
74 .finish()
75 }
76}
77
78impl Clone for Automaton {
79 fn clone(&self) -> Self {
80 let bytes = self.inner.serialize();
81 Self::deserialize_unchecked(&bytes)
82 }
83}
84
85impl Automaton {
86 pub fn empty() -> Self {
91 let dummy_pattern: &[u8] = &[0xFF, 0xFE, 0xFD, 0xFC, 0xFB, 0xFA, 0xF9, 0xF8];
94 match DoubleArrayAhoCorasick::new([dummy_pattern]) {
95 Ok(ac) => Self { inner: ac },
96 Err(_) => panic!("Failed to create empty automaton"),
97 }
98 }
99
100 pub fn find_overlapping_iter(&self, haystack: &[u8]) -> FindOverlappingIter {
110 FindOverlappingIter::new(&self.inner, haystack)
111 }
112
113 pub fn deserialize_unchecked(bytes: &[u8]) -> Self {
118 let (ac, _) = unsafe { DoubleArrayAhoCorasick::deserialize_unchecked(bytes) };
119 Self { inner: ac }
120 }
121
122 pub fn num_states(&self) -> usize {
124 self.inner.num_states()
125 }
126
127 pub fn heap_bytes(&self) -> usize {
129 self.inner.heap_bytes()
130 }
131
132 pub fn serialize_bytes(&self) -> Vec<u8> {
134 self.inner.serialize()
135 }
136}
137
138impl Default for Automaton {
139 fn default() -> Self {
140 Self::empty()
141 }
142}
143
/// Iterator over the matches found by [`Automaton::find_overlapping_iter`].
///
/// The raw `daachorse` matches are collected when the iterator is created;
/// iteration then converts them to [`Match`] values.
pub struct FindOverlappingIter {
    /// Pre-collected raw matches that have not been yielded yet.
    inner: std::vec::IntoIter<daachorse::Match<u32>>,
}
155
156impl FindOverlappingIter {
157 fn new(automaton: &DoubleArrayAhoCorasick<u32>, haystack: &[u8]) -> Self {
158 let matches: Vec<_> = automaton.find_overlapping_iter(haystack).collect();
159 Self {
160 inner: matches.into_iter(),
161 }
162 }
163}
164
165impl Iterator for FindOverlappingIter {
166 type Item = Match;
167
168 fn next(&mut self) -> Option<Self::Item> {
169 loop {
170 let m = self.inner.next()?;
171 if m.start() % 2 == 0 {
174 return Some(Match {
175 rule_id: RuleId::new(m.value() as usize),
176 start: m.start(),
177 end: m.end(),
178 });
179 }
180 }
182 }
183}
184
/// Accumulates byte patterns and their associated values before building an
/// [`Automaton`].
pub struct AutomatonBuilder {
    /// Registered (non-empty) patterns, in insertion order.
    patterns: Vec<Vec<u8>>,
    /// Value for each pattern; `values[i]` belongs to `patterns[i]`.
    values: Vec<u32>,
}
192
193impl AutomatonBuilder {
194 pub fn new() -> Self {
196 Self {
197 patterns: Vec::new(),
198 values: Vec::new(),
199 }
200 }
201
202 pub fn add_pattern_with_value(&mut self, pattern: &[u8], value: u32) {
207 if !pattern.is_empty() {
208 self.patterns.push(pattern.to_vec());
209 self.values.push(value);
210 }
211 }
212
213 pub fn add_pattern(&mut self, pattern: &[u8]) {
217 let value = self.patterns.len() as u32;
218 self.add_pattern_with_value(pattern, value);
219 }
220
221 pub fn build(self) -> Automaton {
226 if self.patterns.is_empty() {
227 return Automaton::empty();
228 }
229
230 let patvals: Vec<(&[u8], u32)> = self
231 .patterns
232 .iter()
233 .zip(self.values.iter())
234 .map(|(p, &v)| (p.as_slice(), v))
235 .collect();
236
237 match DoubleArrayAhoCorasick::with_values(patvals) {
238 Ok(ac) => Automaton { inner: ac },
239 Err(_) => Automaton::empty(),
240 }
241 }
242}
243
244impl Default for AutomatonBuilder {
245 fn default() -> Self {
246 Self::new()
247 }
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the overlapping search and gathers every yielded match.
    fn collect_matches(automaton: &Automaton, haystack: &[u8]) -> Vec<Match> {
        automaton.find_overlapping_iter(haystack).collect()
    }

    /// A pattern occurrence that starts at an odd byte offset is filtered out.
    #[test]
    fn test_token_boundary_filtering() {
        let mut builder = AutomatonBuilder::new();
        builder.add_pattern(&[31, 49]);
        let automaton = builder.build();

        // The pattern occurs at byte offset 1, which is odd.
        let found = collect_matches(&automaton, &[109, 31, 49, 74]);
        assert!(
            found.is_empty(),
            "Should not match across token boundaries"
        );
    }

    /// A pattern occurrence that starts at an even byte offset is reported
    /// with an exclusive end offset.
    #[test]
    fn test_valid_token_match() {
        let mut builder = AutomatonBuilder::new();
        builder.add_pattern(&[31, 49]);
        let automaton = builder.build();

        let found = collect_matches(&automaton, &[0, 0, 31, 49, 0, 0]);
        assert_eq!(found.len(), 1);
        let only = &found[0];
        assert_eq!(only.start, 2);
        assert_eq!(only.end, 4);
    }

    /// Empty patterns are ignored and do not prevent later patterns from matching.
    #[test]
    fn test_builder_skips_empty_patterns() {
        let mut builder = AutomatonBuilder::new();
        for pattern in [&b""[..], &b"hello"[..], &b""[..]] {
            builder.add_pattern(pattern);
        }
        let automaton = builder.build();

        let found = collect_matches(&automaton, b"hello");
        assert_eq!(found.len(), 1);
    }

    /// Explicit values are surfaced as the rule ids of the matches.
    #[test]
    fn test_builder_with_values() {
        let mut builder = AutomatonBuilder::new();
        builder.add_pattern_with_value(b"hello", 42);
        builder.add_pattern_with_value(b"world", 99);
        let automaton = builder.build();

        let found = collect_matches(&automaton, b"hello world");
        assert_eq!(found.len(), 2);
        assert_eq!(found[0].rule_id, RuleId::new(42));
        assert_eq!(found[1].rule_id, RuleId::new(99));
    }

    /// The same pattern registered twice is expected to yield one match per value.
    ///
    /// NOTE(review): daachorse documents duplicate patterns as a build error,
    /// in which case `build` returns the placeholder automaton and this test
    /// would see zero matches — confirm against the pinned daachorse version.
    #[test]
    fn test_builder_duplicate_patterns() {
        let mut builder = AutomatonBuilder::new();
        builder.add_pattern_with_value(b"hello", 10);
        builder.add_pattern_with_value(b"hello", 20);
        let automaton = builder.build();

        let found = collect_matches(&automaton, b"hello");
        assert_eq!(found.len(), 2);

        let mut rule_ids: Vec<RuleId> = found.iter().map(|m| m.rule_id).collect();
        rule_ids.sort();
        assert_eq!(rule_ids, vec![RuleId::new(10), RuleId::new(20)]);
    }
}