use daachorse::DoubleArrayAhoCorasick;
use rancor::Fallible;
use rkyv::with::{ArchiveWith, DeserializeWith, SerializeWith};
use rkyv::{Archive, Deserialize, Place, Serialize};
/// Marker type for rkyv's `with` wrappers: archives an [`Automaton`] field as the
/// byte vector returned by [`Automaton::serialize_bytes`] and rebuilds it from
/// those bytes on deserialization.
pub struct AsBytes;
/// An [`Automaton`] field wrapped with [`AsBytes`] has the same archived layout as a
/// `Vec<u8>`.
impl ArchiveWith<Automaton> for AsBytes {
    type Archived = <Vec<u8> as Archive>::Archived;
    type Resolver = <Vec<u8> as Archive>::Resolver;

    /// Resolves the archived byte vector for `field` into `out`.
    ///
    /// The automaton is serialized into a temporary `Vec<u8>` here, and that
    /// vector's own `Archive::resolve` does the actual work.
    fn resolve_with(field: &Automaton, resolver: Self::Resolver, out: Place<Self::Archived>) {
        let bytes = field.serialize_bytes();
        <Vec<u8> as Archive>::resolve(&bytes, resolver, out);
    }
}
impl<S: Fallible + rkyv::ser::Writer + rkyv::ser::Allocator + ?Sized> SerializeWith<Automaton, S>
for AsBytes
{
fn serialize_with(field: &Automaton, serializer: &mut S) -> Result<Self::Resolver, S::Error> {
field.serialize_bytes().serialize(serializer)
}
}
impl<D: Fallible + ?Sized> DeserializeWith<<Vec<u8> as Archive>::Archived, Automaton, D> for AsBytes
where
<Vec<u8> as Archive>::Archived: Deserialize<Vec<u8>, D>,
{
fn deserialize_with(
field: &<Vec<u8> as Archive>::Archived,
deserializer: &mut D,
) -> Result<Automaton, D::Error> {
let bytes: Vec<u8> = field.deserialize(deserializer)?;
Ok(Automaton::deserialize_unchecked(&bytes))
}
}
/// A single pattern occurrence reported by [`FindOverlappingIter`].
///
/// The type is a small plain-value record, so it is `Copy` and `Hash` in addition to
/// the comparison traits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Match {
    /// Value the automaton associates with the matched pattern, widened to `usize`.
    pub pattern: usize,
    /// Byte offset in the haystack where the match begins.
    pub start: usize,
    /// Byte offset in the haystack just past the last matched byte.
    pub end: usize,
}
/// Multi-pattern byte matcher backed by a daachorse double-array Aho-Corasick
/// automaton whose pattern values are `u32`.
pub struct Automaton {
    // The wrapped daachorse automaton; every search and (de)serialization call is
    // forwarded to it.
    inner: DoubleArrayAhoCorasick<u32>,
}
/// Compact debug output: prints the automaton's size statistics rather than its
/// internal tables.
impl std::fmt::Debug for Automaton {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let num_states = self.inner.num_states();
        let heap_bytes = self.inner.heap_bytes();

        let mut out = f.debug_struct("Automaton");
        out.field("num_states", &num_states);
        out.field("heap_bytes", &heap_bytes);
        out.finish()
    }
}
impl Clone for Automaton {
fn clone(&self) -> Self {
let bytes = self.inner.serialize();
Self::deserialize_unchecked(&bytes)
}
}
impl Automaton {
    /// Creates an automaton that is intended to match nothing.
    ///
    /// It is built from a single fixed 8-byte placeholder pattern, presumably because
    /// the underlying constructor cannot be given zero patterns.
    /// NOTE(review): a haystack that contains this exact byte sequence would still
    /// match the placeholder.
    ///
    /// # Panics
    /// Panics if the placeholder automaton cannot be constructed.
    pub fn empty() -> Self {
        let dummy_pattern: &[u8] = &[0xFF, 0xFE, 0xFD, 0xFC, 0xFB, 0xFA, 0xF9, 0xF8];
        match DoubleArrayAhoCorasick::new([dummy_pattern]) {
            Ok(ac) => Self { inner: ac },
            Err(_) => panic!("Failed to create empty automaton"),
        }
    }

    /// Builds an automaton from `patterns`.
    ///
    /// Empty patterns are skipped and repeated patterns are kept only once (first
    /// occurrence wins), mirroring [`AutomatonBuilder::build`]. The surviving patterns
    /// are handed to daachorse in their original relative order, so the value reported
    /// in [`Match::pattern`] refers to a position in that filtered list.
    ///
    /// Returns [`Automaton::empty`] when no usable pattern remains or when the
    /// underlying construction fails.
    #[allow(dead_code)]
    pub fn build(patterns: &[&[u8]]) -> Self {
        use std::collections::HashSet;

        // daachorse rejects both empty and duplicate patterns. Without this filter a
        // single repeated pattern would fall into the `Err` arm below and silently
        // produce an automaton that matches nothing.
        let mut seen: HashSet<&[u8]> = HashSet::with_capacity(patterns.len());
        let usable: Vec<&[u8]> = patterns
            .iter()
            .copied()
            .filter(|p| !p.is_empty() && seen.insert(*p))
            .collect();

        if usable.is_empty() {
            return Self::empty();
        }
        match DoubleArrayAhoCorasick::new(usable) {
            Ok(ac) => Self { inner: ac },
            // Best-effort fallback kept from the original behaviour.
            Err(_) => Self::empty(),
        }
    }

    /// Returns an iterator over the overlapping matches of all patterns in `haystack`.
    pub fn find_overlapping_iter(&self, haystack: &[u8]) -> FindOverlappingIter {
        FindOverlappingIter::new(&self.inner, haystack)
    }

    /// Rebuilds an automaton from bytes produced by [`Automaton::serialize_bytes`].
    ///
    /// The bytes are **not** validated, and any bytes left over after the automaton
    /// has been read are ignored.
    /// NOTE(review): this is a safe function wrapping an unchecked operation. Passing
    /// bytes that did not come from `serialize_bytes` is the caller's responsibility
    /// to avoid; consider an `unsafe fn` or a validating variant.
    pub fn deserialize_unchecked(bytes: &[u8]) -> Self {
        // SAFETY: relies on the caller supplying bytes previously produced by
        // `serialize_bytes`; this function performs no check of its own.
        let (ac, _) = unsafe { DoubleArrayAhoCorasick::deserialize_unchecked(bytes) };
        Self { inner: ac }
    }

    /// Number of states in the underlying automaton.
    #[allow(dead_code)]
    pub fn num_states(&self) -> usize {
        self.inner.num_states()
    }

    /// Heap memory used by the underlying automaton, in bytes.
    #[allow(dead_code)]
    pub fn heap_bytes(&self) -> usize {
        self.inner.heap_bytes()
    }

    /// Serializes the underlying automaton into a byte vector suitable for
    /// [`Automaton::deserialize_unchecked`].
    pub fn serialize_bytes(&self) -> Vec<u8> {
        self.inner.serialize()
    }
}
impl Default for Automaton {
fn default() -> Self {
Self::empty()
}
}
/// Owning iterator over the matches found by [`Automaton::find_overlapping_iter`].
///
/// The raw daachorse matches are stored in a vector up front, so the iterator does
/// not borrow the automaton or the haystack.
pub struct FindOverlappingIter {
    // Buffered raw matches, consumed front to back by `next`.
    inner: std::vec::IntoIter<daachorse::Match<u32>>,
}
impl FindOverlappingIter {
    /// Runs the overlapping search over `haystack` immediately and buffers every raw
    /// match, so the returned iterator owns its data.
    fn new(automaton: &DoubleArrayAhoCorasick<u32>, haystack: &[u8]) -> Self {
        let inner = automaton
            .find_overlapping_iter(haystack)
            .collect::<Vec<_>>()
            .into_iter();
        Self { inner }
    }
}
/// Yields only the buffered matches that start at an even byte offset, converted to
/// [`Match`].
///
/// NOTE(review): the even-offset rule presumably reflects 2-byte tokens in the
/// haystack, so that matches straddling a token boundary are dropped — confirm
/// against the callers.
impl Iterator for FindOverlappingIter {
    type Item = Match;

    fn next(&mut self) -> Option<Self::Item> {
        // Skip raw matches with an odd start offset; stop when the buffer is drained.
        self.inner
            .find(|raw| raw.start() % 2 == 0)
            .map(|raw| Match {
                pattern: raw.value() as usize,
                start: raw.start(),
                end: raw.end(),
            })
    }
}
/// Incrementally collects byte patterns and turns them into an [`Automaton`].
///
/// `Default` is implemented by hand elsewhere in this file, so it is not derived here.
#[derive(Debug, Clone)]
pub struct AutomatonBuilder {
    // Patterns in insertion order; may contain repeats until `build` is called.
    patterns: Vec<Vec<u8>>,
}
impl AutomatonBuilder {
    /// Creates a builder with no patterns.
    pub fn new() -> Self {
        Self {
            patterns: Vec::new(),
        }
    }

    /// Queues `pattern` for inclusion in the automaton.
    ///
    /// Empty patterns are ignored.
    pub fn add_pattern(&mut self, pattern: &[u8]) {
        if !pattern.is_empty() {
            self.patterns.push(pattern.to_vec());
        }
    }

    /// Consumes the builder and constructs the automaton.
    ///
    /// Repeated patterns are kept only once (first occurrence wins) and the remaining
    /// patterns are passed on in insertion order. Returns [`Automaton::empty`] when no
    /// pattern was added or when the underlying construction fails.
    pub fn build(self) -> Automaton {
        use std::collections::HashSet;

        // De-duplicate on borrowed slices: membership only needs a view of the bytes,
        // so no per-pattern copy is made.
        let mut seen: HashSet<&[u8]> = HashSet::with_capacity(self.patterns.len());
        let unique_patterns: Vec<&[u8]> = self
            .patterns
            .iter()
            .map(Vec::as_slice)
            .filter(|pattern| seen.insert(*pattern))
            .collect();

        // Empty exactly when nothing was added: the first stored pattern always
        // survives de-duplication.
        if unique_patterns.is_empty() {
            return Automaton::empty();
        }
        match DoubleArrayAhoCorasick::new(unique_patterns) {
            Ok(ac) => Automaton { inner: ac },
            // Best-effort fallback kept from the original behaviour.
            Err(_) => Automaton::empty(),
        }
    }
}
impl Default for AutomatonBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the overlapping search and gathers every reported match.
    fn collect_matches(ac: &Automaton, haystack: &[u8]) -> Vec<Match> {
        ac.find_overlapping_iter(haystack).collect()
    }

    /// The placeholder automaton reports nothing for ordinary input.
    #[test]
    fn test_empty_automaton() {
        let ac = Automaton::empty();
        let found = collect_matches(&ac, b"hello");
        assert!(found.is_empty());
    }

    /// Both patterns start at even offsets (0 and 6), so both are reported.
    #[test]
    fn test_build_with_patterns() {
        let patterns: [&[u8]; 2] = [b"hello", b"world"];
        let ac = Automaton::build(&patterns);
        let found = collect_matches(&ac, b"hello world");
        assert_eq!(found.len(), 2);
    }

    /// A match starting at an odd offset is filtered out.
    #[test]
    fn test_token_boundary_filtering() {
        let pattern: &[u8] = &[31, 49];
        let haystack: &[u8] = &[109, 31, 49, 74];
        let ac = Automaton::build(&[pattern]);
        let found = collect_matches(&ac, haystack);
        assert!(
            found.is_empty(),
            "Should not match across token boundaries"
        );
    }

    /// A match starting at an even offset is kept, with an exclusive end offset.
    #[test]
    fn test_valid_token_match() {
        let pattern: &[u8] = &[31, 49];
        let haystack: &[u8] = &[0, 0, 31, 49, 0, 0];
        let ac = Automaton::build(&[pattern]);
        let found = collect_matches(&ac, haystack);
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].start, 2);
        assert_eq!(found[0].end, 4);
    }

    /// The builder produces a working automaton from added patterns.
    #[test]
    fn test_builder() {
        let mut builder = AutomatonBuilder::new();
        for pattern in [&b"hello"[..], &b"world"[..]] {
            builder.add_pattern(pattern);
        }
        let ac = builder.build();
        let found = collect_matches(&ac, b"hello world");
        assert_eq!(found.len(), 2);
    }

    /// A builder with no patterns yields an automaton that matches nothing here.
    #[test]
    fn test_builder_empty_patterns() {
        let ac = AutomatonBuilder::new().build();
        let found = collect_matches(&ac, b"hello");
        assert!(found.is_empty());
    }

    /// Empty patterns passed to the builder are ignored.
    #[test]
    fn test_builder_skips_empty_patterns() {
        let mut builder = AutomatonBuilder::new();
        for pattern in [&b""[..], &b"hello"[..], &b""[..]] {
            builder.add_pattern(pattern);
        }
        let ac = builder.build();
        let found = collect_matches(&ac, b"hello");
        assert_eq!(found.len(), 1);
    }

    /// An automaton rebuilt from its serialized bytes reports the same matches.
    #[test]
    fn test_serialize_deserialize() {
        let patterns: [&[u8]; 3] = [b"hello", b"world", b"test"];
        let original = Automaton::build(&patterns);
        let bytes = original.inner.serialize();
        let restored = Automaton::deserialize_unchecked(&bytes);

        let haystack = b"hello world test";
        let expected = collect_matches(&original, haystack);
        let actual = collect_matches(&restored, haystack);

        // Same length and field-by-field equality of every pair.
        assert_eq!(expected.len(), actual.len());
        assert_eq!(expected, actual);
    }

    /// Overlapping patterns are all considered for the same haystack region.
    #[test]
    fn test_overlapping_matches() {
        let patterns: [&[u8]; 3] = [b"ab", b"bc", b"abc"];
        let ac = Automaton::build(&patterns);
        let found = collect_matches(&ac, b"abc");
        assert!(found.len() >= 2);
    }
}