use crate::license_detection::models::RuleId;
use daachorse::DoubleArrayAhoCorasick;
use rancor::Fallible;
use rkyv::with::{ArchiveWith, DeserializeWith, SerializeWith};
use rkyv::{Archive, Deserialize, Place, Serialize};
/// Marker type for rkyv's `with` mechanism.
///
/// The `ArchiveWith`, `SerializeWith` and `DeserializeWith` implementations in
/// this file attach to this type so that an `Automaton` field is archived as a
/// `Vec<u8>` holding the automaton's serialized bytes.
///
/// The type carries no data, so the usual marker traits are derived; public
/// types are expected to be at least `Debug`.
#[derive(Debug, Clone, Copy, Default)]
pub struct AsBytes;
/// Archives an [`Automaton`] with the same archived representation and
/// resolver as a `Vec<u8>`.
impl ArchiveWith<Automaton> for AsBytes {
    type Archived = <Vec<u8> as Archive>::Archived;
    type Resolver = <Vec<u8> as Archive>::Resolver;

    /// Resolves the field by serializing the automaton to bytes and delegating
    /// to the `Vec<u8>` implementation of `Archive::resolve`.
    fn resolve_with(field: &Automaton, resolver: Self::Resolver, out: Place<Self::Archived>) {
        // The automaton is serialized again here so that the byte vector can
        // drive the `Vec<u8>` resolve step.
        let bytes: Vec<u8> = field.serialize_bytes();
        <Vec<u8> as Archive>::resolve(&bytes, resolver, out);
    }
}
impl<S: Fallible + rkyv::ser::Writer + rkyv::ser::Allocator + ?Sized> SerializeWith<Automaton, S>
for AsBytes
{
fn serialize_with(field: &Automaton, serializer: &mut S) -> Result<Self::Resolver, S::Error> {
field.serialize_bytes().serialize(serializer)
}
}
impl<D: Fallible + ?Sized> DeserializeWith<<Vec<u8> as Archive>::Archived, Automaton, D> for AsBytes
where
<Vec<u8> as Archive>::Archived: Deserialize<Vec<u8>, D>,
{
fn deserialize_with(
field: &<Vec<u8> as Archive>::Archived,
deserializer: &mut D,
) -> Result<Automaton, D::Error> {
let bytes: Vec<u8> = field.deserialize(deserializer)?;
Ok(Automaton::deserialize_unchecked(&bytes))
}
}
/// A single pattern hit yielded by `FindOverlappingIter`.
///
/// All three fields are copied from the underlying `daachorse::Match<u32>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Match {
/// Rule identifier, built from the `u32` value attached to the matched pattern.
pub rule_id: RuleId,
/// Start position of the match in the haystack, as reported by daachorse.
pub start: usize,
/// End position of the match in the haystack, as reported by daachorse.
/// The tests in this file expect `end == 4` for a two-byte pattern found at
/// offset 2, i.e. the position just past the last matched byte.
pub end: usize,
}
/// Multi-pattern matcher wrapping a daachorse double-array Aho-Corasick
/// automaton whose patterns carry `u32` values.
pub struct Automaton {
/// The underlying daachorse automaton; the methods on `Automaton` delegate to it.
inner: DoubleArrayAhoCorasick<u32>,
}
/// Compact debug output: only summary statistics of the wrapped automaton are
/// printed, not its contents.
impl std::fmt::Debug for Automaton {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let state_count = self.inner.num_states();
        let heap_usage = self.inner.heap_bytes();

        let mut out = f.debug_struct("Automaton");
        out.field("num_states", &state_count);
        out.field("heap_bytes", &heap_usage);
        out.finish()
    }
}
impl Clone for Automaton {
fn clone(&self) -> Self {
let bytes = self.inner.serialize();
Self::deserialize_unchecked(&bytes)
}
}
impl Automaton {
    /// Creates the automaton used when there are no patterns to match.
    ///
    /// The underlying automaton is built from one fixed eight-byte placeholder
    /// pattern (`FF FE FD FC FB FA F9 F8`) rather than from zero patterns.
    ///
    /// NOTE(review): a haystack that contains this exact byte run could still
    /// produce a match. Presumably the sequence cannot occur in real input;
    /// confirm against the callers.
    ///
    /// # Panics
    ///
    /// Panics with "Failed to create empty automaton" if daachorse cannot
    /// build the placeholder automaton.
    pub fn empty() -> Self {
        const PLACEHOLDER: &[u8] = &[0xFF, 0xFE, 0xFD, 0xFC, 0xFB, 0xFA, 0xF9, 0xF8];

        let Ok(inner) = DoubleArrayAhoCorasick::new([PLACEHOLDER]) else {
            panic!("Failed to create empty automaton");
        };
        Self { inner }
    }

    /// Rebuilds an automaton from bytes previously produced by
    /// [`Automaton::serialize_bytes`].
    ///
    /// No validation is performed: the bytes are passed directly to
    /// daachorse's unsafe `deserialize_unchecked`. Any bytes left over after
    /// the automaton has been read are ignored.
    ///
    /// NOTE(review): this is a safe function wrapping an unsafe call, so
    /// callers must only pass trusted bytes written by `serialize_bytes`.
    pub fn deserialize_unchecked(bytes: &[u8]) -> Self {
        // SAFETY: relies on the caller contract documented above; nothing is
        // checked here.
        let (inner, _unread) = unsafe { DoubleArrayAhoCorasick::deserialize_unchecked(bytes) };
        Self { inner }
    }

    /// Serializes the underlying automaton into an owned byte vector.
    pub fn serialize_bytes(&self) -> Vec<u8> {
        self.inner.serialize()
    }

    /// Returns an iterator over the overlapping matches found in `haystack`.
    ///
    /// See [`FindOverlappingIter`] for the filtering applied to raw matches.
    pub fn find_overlapping_iter(&self, haystack: &[u8]) -> FindOverlappingIter {
        let automaton = &self.inner;
        FindOverlappingIter::new(automaton, haystack)
    }

    /// Number of states, as reported by the underlying automaton.
    pub fn num_states(&self) -> usize {
        self.inner.num_states()
    }

    /// Heap usage in bytes, as reported by the underlying automaton.
    pub fn heap_bytes(&self) -> usize {
        self.inner.heap_bytes()
    }
}
impl Default for Automaton {
fn default() -> Self {
Self::empty()
}
}
/// Iterator over matches produced by `Automaton::find_overlapping_iter`.
///
/// The raw daachorse matches are stored in an owned vector, so this type
/// carries no lifetime tied to the automaton or the haystack.
pub struct FindOverlappingIter {
/// Owned iterator over the raw daachorse matches collected at construction.
inner: std::vec::IntoIter<daachorse::Match<u32>>,
}
impl FindOverlappingIter {
    /// Runs the overlapping search over `haystack` immediately and buffers
    /// every raw match, so the returned iterator owns its data.
    fn new(automaton: &DoubleArrayAhoCorasick<u32>, haystack: &[u8]) -> Self {
        let buffered = automaton
            .find_overlapping_iter(haystack)
            .collect::<Vec<_>>();
        let inner = buffered.into_iter();
        Self { inner }
    }
}
/// Yields only the raw matches whose start offset is even; matches starting at
/// an odd offset are skipped (the tests in this file describe these as matches
/// "across token boundaries").
impl Iterator for FindOverlappingIter {
    type Item = Match;

    fn next(&mut self) -> Option<Self::Item> {
        // Advance past odd-offset matches, then convert the first even-offset
        // one into the public `Match` type.
        self.inner
            .find(|raw| raw.start() % 2 == 0)
            .map(|raw| Match {
                rule_id: RuleId::new(raw.value() as usize),
                start: raw.start(),
                end: raw.end(),
            })
    }
}
/// Collects byte patterns and their associated `u32` values before an
/// automaton is built from them.
///
/// `patterns` and `values` are parallel vectors: the value at index `i`
/// belongs to the pattern at index `i`.
///
/// `Debug` and `Clone` are derived because both fields are plain standard
/// collections and public types are expected to be at least `Debug`.
#[derive(Debug, Clone)]
pub struct AutomatonBuilder {
    /// Registered patterns, in insertion order.
    patterns: Vec<Vec<u8>>,
    /// Value reported for the pattern at the same index in `patterns`.
    values: Vec<u32>,
}
impl AutomatonBuilder {
    /// Creates a builder with no patterns registered.
    pub fn new() -> Self {
        let patterns = Vec::new();
        let values = Vec::new();
        Self { patterns, values }
    }

    /// Registers `pattern` together with the `value` to report when it matches.
    ///
    /// Empty patterns are ignored.
    pub fn add_pattern_with_value(&mut self, pattern: &[u8], value: u32) {
        if pattern.is_empty() {
            return;
        }
        self.patterns.push(pattern.to_vec());
        self.values.push(value);
    }

    /// Registers `pattern`, using the number of patterns stored so far as its
    /// value.
    ///
    /// Empty patterns are ignored and do not advance that count.
    pub fn add_pattern(&mut self, pattern: &[u8]) {
        let next_index = self.patterns.len() as u32;
        self.add_pattern_with_value(pattern, next_index);
    }

    /// Consumes the builder and constructs the automaton.
    ///
    /// Returns [`Automaton::empty`] when no patterns were registered. If
    /// daachorse rejects the pattern set, the error is discarded and
    /// [`Automaton::empty`] is returned as well.
    ///
    /// NOTE(review): a construction error therefore silently yields an
    /// automaton that matches none of the registered patterns. Confirm this
    /// best-effort fallback is intended.
    pub fn build(self) -> Automaton {
        if self.patterns.is_empty() {
            return Automaton::empty();
        }

        // Pair each pattern slice with the value stored at the same index.
        let pairs: Vec<(&[u8], u32)> = self
            .patterns
            .iter()
            .map(Vec::as_slice)
            .zip(self.values.iter().copied())
            .collect();

        DoubleArrayAhoCorasick::with_values(pairs)
            .map_or_else(|_| Automaton::empty(), |inner| Automaton { inner })
    }
}
impl Default for AutomatonBuilder {
fn default() -> Self {
Self::new()
}
}
// Unit tests for the automaton wrapper and its builder.
#[cfg(test)]
mod tests {
use super::*;
// The pattern occurs at haystack offset 1 (odd), so the iterator must drop it.
#[test]
fn test_token_boundary_filtering() {
let pattern: &[u8] = &[31, 49];
let mut builder = AutomatonBuilder::new();
builder.add_pattern(pattern);
let ac = builder.build();
let haystack: &[u8] = &[109, 31, 49, 74];
let matches: Vec<_> = ac.find_overlapping_iter(haystack).collect();
assert!(
matches.is_empty(),
"Should not match across token boundaries"
);
}
// The pattern occurs at haystack offset 2 (even): exactly one match spanning 2..4.
#[test]
fn test_valid_token_match() {
let pattern: &[u8] = &[31, 49];
let mut builder = AutomatonBuilder::new();
builder.add_pattern(pattern);
let ac = builder.build();
let haystack: &[u8] = &[0, 0, 31, 49, 0, 0];
let matches: Vec<_> = ac.find_overlapping_iter(haystack).collect();
assert_eq!(matches.len(), 1);
assert_eq!(matches[0].start, 2);
assert_eq!(matches[0].end, 4);
}
// Empty patterns are ignored by the builder; only "hello" is registered.
#[test]
fn test_builder_skips_empty_patterns() {
let mut builder = AutomatonBuilder::new();
builder.add_pattern(b"");
builder.add_pattern(b"hello");
builder.add_pattern(b"");
let ac = builder.build();
let matches: Vec<_> = ac.find_overlapping_iter(b"hello").collect();
assert_eq!(matches.len(), 1);
}
// Explicit values are reported as rule ids; "hello" starts at offset 0 and
// "world" at offset 6, both even, so neither match is filtered out.
#[test]
fn test_builder_with_values() {
let mut builder = AutomatonBuilder::new();
builder.add_pattern_with_value(b"hello", 42);
builder.add_pattern_with_value(b"world", 99);
let ac = builder.build();
let matches: Vec<_> = ac.find_overlapping_iter(b"hello world").collect();
assert_eq!(matches.len(), 2);
assert_eq!(matches[0].rule_id, RuleId::new(42));
assert_eq!(matches[1].rule_id, RuleId::new(99));
}
// The same pattern is registered twice with different values and both values
// are expected to be reported.
// NOTE(review): daachorse's constructors are documented as returning an error
// for duplicate patterns, and `build` falls back to an empty automaton on
// error — confirm this expectation holds with the daachorse version in use.
#[test]
fn test_builder_duplicate_patterns() {
let mut builder = AutomatonBuilder::new();
builder.add_pattern_with_value(b"hello", 10);
builder.add_pattern_with_value(b"hello", 20);
let ac = builder.build();
let matches: Vec<_> = ac.find_overlapping_iter(b"hello").collect();
assert_eq!(matches.len(), 2);
// Sort so the assertion does not depend on the order matches are reported in.
let mut values: Vec<RuleId> = matches.iter().map(|m| m.rule_id).collect();
values.sort();
assert_eq!(values, vec![RuleId::new(10), RuleId::new(20)]);
}
}