provenant/license_detection/
automaton.rs1use crate::license_detection::models::RuleId;
12use daachorse::DoubleArrayAhoCorasick;
13use rancor::Fallible;
14use rkyv::with::{ArchiveWith, DeserializeWith, SerializeWith};
15use rkyv::{Archive, Deserialize, Place, Serialize};
16
/// Marker type for rkyv's `with` adapters: stores an [`Automaton`] in an
/// archive as the byte vector returned by [`Automaton::serialize_bytes`].
pub struct AsBytes;
19
20impl ArchiveWith<Automaton> for AsBytes {
21 type Archived = <Vec<u8> as Archive>::Archived;
22 type Resolver = <Vec<u8> as Archive>::Resolver;
23
24 fn resolve_with(field: &Automaton, resolver: Self::Resolver, out: Place<Self::Archived>) {
25 field.serialize_bytes().resolve(resolver, out);
26 }
27}
28
29impl<S: Fallible + rkyv::ser::Writer + rkyv::ser::Allocator + ?Sized> SerializeWith<Automaton, S>
30 for AsBytes
31{
32 fn serialize_with(field: &Automaton, serializer: &mut S) -> Result<Self::Resolver, S::Error> {
33 field.serialize_bytes().serialize(serializer)
34 }
35}
36
37impl<D: Fallible + ?Sized> DeserializeWith<<Vec<u8> as Archive>::Archived, Automaton, D> for AsBytes
38where
39 <Vec<u8> as Archive>::Archived: Deserialize<Vec<u8>, D>,
40{
41 fn deserialize_with(
42 field: &<Vec<u8> as Archive>::Archived,
43 deserializer: &mut D,
44 ) -> Result<Automaton, D::Error> {
45 let bytes: Vec<u8> = field.deserialize(deserializer)?;
46 Ok(Automaton::deserialize_unchecked(&bytes))
47 }
48}
49
50#[derive(Debug, Clone, PartialEq, Eq)]
52pub struct Match {
53 pub rule_id: RuleId,
55 pub start: usize,
57 pub end: usize,
59}
60
61pub struct Automaton {
66 inner: DoubleArrayAhoCorasick<u32>,
67}
68
69impl std::fmt::Debug for Automaton {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 f.debug_struct("Automaton")
72 .field("num_states", &self.inner.num_states())
73 .field("heap_bytes", &self.inner.heap_bytes())
74 .finish()
75 }
76}
77
78impl Clone for Automaton {
79 fn clone(&self) -> Self {
80 let bytes = self.inner.serialize();
81 Self::deserialize_unchecked(&bytes)
82 }
83}
84
85impl Automaton {
86 pub fn empty() -> Self {
91 let dummy_pattern: &[u8] = &[0xFF, 0xFE, 0xFD, 0xFC, 0xFB, 0xFA, 0xF9, 0xF8];
94 let inner = DoubleArrayAhoCorasick::new([dummy_pattern])
95 .expect("Failed to create empty automaton with hardcoded dummy pattern");
96 Self { inner }
97 }
98
99 pub fn find_overlapping_iter(&self, haystack: &[u8]) -> FindOverlappingIter {
109 FindOverlappingIter::new(&self.inner, haystack)
110 }
111
112 pub fn deserialize_unchecked(bytes: &[u8]) -> Self {
117 let (ac, _) = unsafe { DoubleArrayAhoCorasick::deserialize_unchecked(bytes) };
118 Self { inner: ac }
119 }
120
121 pub fn num_states(&self) -> usize {
123 self.inner.num_states()
124 }
125
126 pub fn heap_bytes(&self) -> usize {
128 self.inner.heap_bytes()
129 }
130
131 pub fn serialize_bytes(&self) -> Vec<u8> {
133 self.inner.serialize()
134 }
135}
136
137impl Default for Automaton {
138 fn default() -> Self {
139 Self::empty()
140 }
141}
142
143pub struct FindOverlappingIter {
152 inner: std::vec::IntoIter<daachorse::Match<u32>>,
153}
154
155impl FindOverlappingIter {
156 fn new(automaton: &DoubleArrayAhoCorasick<u32>, haystack: &[u8]) -> Self {
157 let matches: Vec<_> = automaton.find_overlapping_iter(haystack).collect();
158 Self {
159 inner: matches.into_iter(),
160 }
161 }
162}
163
164impl Iterator for FindOverlappingIter {
165 type Item = Match;
166
167 fn next(&mut self) -> Option<Self::Item> {
168 loop {
169 let m = self.inner.next()?;
170 if m.start() % 2 == 0 {
173 return Some(Match {
174 rule_id: RuleId::new(m.value() as usize),
175 start: m.start(),
176 end: m.end(),
177 });
178 }
179 }
181 }
182}
183
/// Accumulates byte patterns and their `u32` values before building an
/// [`Automaton`].
pub struct AutomatonBuilder {
    /// Non-empty patterns, in insertion order.
    patterns: Vec<Vec<u8>>,
    /// Value for each pattern; kept index-aligned with `patterns`.
    values: Vec<u32>,
}
191
192impl AutomatonBuilder {
193 pub fn new() -> Self {
195 Self {
196 patterns: Vec::new(),
197 values: Vec::new(),
198 }
199 }
200
201 pub fn add_pattern_with_value(&mut self, pattern: &[u8], value: u32) {
206 if !pattern.is_empty() {
207 self.patterns.push(pattern.to_vec());
208 self.values.push(value);
209 }
210 }
211
212 pub fn add_pattern(&mut self, pattern: &[u8]) {
216 let value = self.patterns.len() as u32;
217 self.add_pattern_with_value(pattern, value);
218 }
219
220 pub fn build(self) -> Automaton {
225 if self.patterns.is_empty() {
226 return Automaton::empty();
227 }
228
229 let patvals: Vec<(&[u8], u32)> = self
230 .patterns
231 .iter()
232 .zip(self.values.iter())
233 .map(|(p, &v)| (p.as_slice(), v))
234 .collect();
235
236 match DoubleArrayAhoCorasick::with_values(patvals) {
237 Ok(ac) => Automaton { inner: ac },
238 Err(_) => Automaton::empty(),
239 }
240 }
241}
242
243impl Default for AutomatonBuilder {
244 fn default() -> Self {
245 Self::new()
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an automaton from `patterns` using auto-assigned values.
    fn automaton_of(patterns: &[&[u8]]) -> Automaton {
        let mut builder = AutomatonBuilder::new();
        for pattern in patterns {
            builder.add_pattern(pattern);
        }
        builder.build()
    }

    /// Builds an automaton from explicit `(pattern, value)` pairs.
    fn automaton_with_values(entries: &[(&[u8], u32)]) -> Automaton {
        let mut builder = AutomatonBuilder::new();
        for &(pattern, value) in entries {
            builder.add_pattern_with_value(pattern, value);
        }
        builder.build()
    }

    /// Collects every match the automaton yields for `haystack`.
    fn scan(ac: &Automaton, haystack: &[u8]) -> Vec<Match> {
        ac.find_overlapping_iter(haystack).collect()
    }

    // The pattern occurs at offset 1 (odd), so the iterator must drop it.
    #[test]
    fn test_token_boundary_filtering() {
        let ac = automaton_of(&[&[31, 49]]);

        let matches = scan(&ac, &[109, 31, 49, 74]);
        assert!(
            matches.is_empty(),
            "Should not match across token boundaries"
        );
    }

    // The pattern occurs at offset 2 (even), so exactly one match is reported.
    #[test]
    fn test_valid_token_match() {
        let ac = automaton_of(&[&[31, 49]]);

        let matches = scan(&ac, &[0, 0, 31, 49, 0, 0]);
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].start, 2);
        assert_eq!(matches[0].end, 4);
    }

    // Empty patterns are ignored by the builder; only "hello" is registered.
    #[test]
    fn test_builder_skips_empty_patterns() {
        let ac = automaton_of(&[b"", b"hello", b""]);

        let matches = scan(&ac, b"hello");
        assert_eq!(matches.len(), 1);
    }

    // Explicit values come back as rule ids ("hello" at 0, "world" at 6).
    #[test]
    fn test_builder_with_values() {
        let ac = automaton_with_values(&[(b"hello", 42), (b"world", 99)]);

        let matches = scan(&ac, b"hello world");
        assert_eq!(matches.len(), 2);
        assert_eq!(matches[0].rule_id, RuleId::new(42));
        assert_eq!(matches[1].rule_id, RuleId::new(99));
    }

    // NOTE(review): this expects the same pattern registered twice to yield
    // two matches. If the daachorse version in use rejects duplicate
    // patterns, `build` falls back to the placeholder automaton and this
    // test fails — confirm against the pinned daachorse version.
    #[test]
    fn test_builder_duplicate_patterns() {
        let ac = automaton_with_values(&[(b"hello", 10), (b"hello", 20)]);

        let matches = scan(&ac, b"hello");
        assert_eq!(matches.len(), 2);
        let mut values: Vec<RuleId> = matches.iter().map(|m| m.rule_id).collect();
        values.sort();
        assert_eq!(values, vec![RuleId::new(10), RuleId::new(20)]);
    }
}