mod allocator;
mod iter;
pub use iter::{FindIter, OutputIter};
use alloc::vec::Vec;
use allocator::Allocator;
use crate::trie::Trie;
use crate::types::{Match, MatchKind, Output};
/// Growth granularity of the double array, in slots.
///
/// Equal to the number of distinct byte labels, so `base ^ label` for any
/// `u8` label stays inside the 256-aligned block that contains `base`.
pub(crate) const BLOCK_LEN: u32 = 1 << 8;
/// Index of the root state; slot 0 is reserved for the dead state.
pub(crate) const ROOT_IDX: u32 = 1;
/// One slot of the double array.
///
/// A slot is vacant when `check == 0`; the dead and root states use
/// `check == u32::MAX`, and every other occupied slot stores its parent's index.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct State {
    /// XOR offset for children: the child reached on byte `c` lives at `base ^ c`.
    pub base: u32,
    /// Parent state index, `0` for a vacant slot, or `u32::MAX` for dead/root.
    pub check: u32,
    /// Failure link followed when no child matches the next byte.
    pub fail: u32,
    /// Output position handed to the output iterator; `u32::MAX` means the
    /// state has no outputs.
    pub outpos: u32,
}
/// Aho–Corasick automaton stored as an XOR-indexed double array.
#[derive(Clone, Debug)]
pub struct DoubleArrayAhoCorasick {
    /// Double-array slots; slot 0 is the dead state and slot `ROOT_IDX` the root.
    pub states: Vec<State>,
    /// Output table that each state's `outpos` points into.
    pub outputs: Vec<Output>,
    /// Match semantics; copied from the source trie by `from_trie`.
    pub match_kind: MatchKind,
    /// Optional state index remapped from the source trie's anchor
    /// (its search-time meaning is defined by the iterators, not here).
    pub anchor: Option<u32>,
}
impl Default for DoubleArrayAhoCorasick {
fn default() -> Self {
Self::new()
}
}
impl DoubleArrayAhoCorasick {
pub fn new() -> Self {
let dead = State {
base: 0,
check: u32::MAX,
fail: 0,
outpos: u32::MAX,
};
let root = State {
base: 0,
check: u32::MAX,
fail: ROOT_IDX,
outpos: u32::MAX,
};
Self {
states: vec![dead, root],
outputs: Vec::new(),
match_kind: MatchKind::default(),
anchor: None,
}
}
fn is_vacant(&self, idx: u32) -> bool {
idx as usize >= self.states.len() || self.states[idx as usize].check == 0
}
fn is_valid_base(&self, base: u32, labels: &[u8]) -> bool {
labels.iter().all(|&c| self.is_vacant(base ^ (c as u32)))
}
fn find_base(&mut self, labels: &[u8], allocator: &mut Allocator) -> u32 {
for idx in allocator.iter() {
let base = idx ^ (labels[0] as u32);
if self.is_valid_base(base, labels) {
return base;
}
}
let old_size = self.states.len() as u32;
let num_blocks = (old_size + BLOCK_LEN - 1) / BLOCK_LEN;
let new_block_end = (num_blocks + 1) * BLOCK_LEN;
self.states.resize(
new_block_end as usize,
State {
base: 0,
check: 0,
fail: 0,
outpos: u32::MAX,
},
);
allocator.extend(old_size, new_block_end);
for idx in allocator.iter() {
let base = idx ^ (labels[0] as u32);
if self.is_valid_base(base, labels) {
return base;
}
}
panic!("Could not find valid base after extending");
}
fn build_recursive(
&mut self,
trie: &Trie,
trie_state: usize,
daac_state: u32,
mapping: &mut Vec<u32>,
allocator: &mut Allocator,
) {
let edges: Vec<(u8, u32)> = trie.states[trie_state]
.edges
.iter()
.map(|(&label, &child)| (label, child))
.collect();
if edges.is_empty() {
return;
}
let labels: Vec<u8> = edges.iter().map(|(l, _)| *l).collect();
let base = self.find_base(&labels, allocator);
self.states[daac_state as usize].base = base;
for &(label, trie_child) in &edges {
let daac_child = base ^ (label as u32);
if daac_child as usize >= self.states.len() {
let old_size = self.states.len() as u32;
self.states.resize(
daac_child as usize + 1,
State {
base: 0,
check: 0,
fail: 0,
outpos: u32::MAX,
},
);
allocator.extend(old_size, self.states.len() as u32);
}
self.states[daac_child as usize].check = daac_state;
allocator.delete(daac_child);
mapping[trie_child as usize] = daac_child;
}
for (label, trie_child) in edges {
let daac_child = base ^ (label as u32);
self.build_recursive(trie, trie_child as usize, daac_child, mapping, allocator);
}
}
pub fn from_trie(trie: Trie) -> Self {
use crate::trie::DEAD_STATE;
let mut daac = Self::new();
daac.match_kind = trie.match_kind;
let mut allocator = Allocator::new(daac.states.len());
let mut mapping = vec![0u32; trie.num_states()];
mapping[0] = ROOT_IDX;
daac.build_recursive(&trie, 0, ROOT_IDX, &mut mapping, &mut allocator);
for trie_id in 0..trie.num_states() {
let daac_id = mapping[trie_id];
let trie_fail = trie.states[trie_id].fail;
daac.states[daac_id as usize].fail = if trie_fail == DEAD_STATE {
DEAD_STATE
} else {
mapping[trie_fail as usize]
};
daac.states[daac_id as usize].outpos = trie.states[trie_id].outpos.unwrap_or(u32::MAX);
}
daac.anchor = trie.anchor.map(|a| mapping[a as usize]);
daac.outputs = trie.outputs;
daac
}
pub fn find_iter<'a>(&'a self, text: &'a [u8]) -> FindIter<'a> {
FindIter::new(self, text)
}
pub fn find(&self, text: &[u8]) -> Vec<Match> {
self.find_iter(text).collect()
}
#[inline]
pub fn start_state(&self) -> u32 {
ROOT_IDX
}
#[inline]
pub fn next_state(&self, mut state: u32, byte: u8) -> u32 {
use crate::trie::DEAD_STATE;
let states = &self.states;
let states_len = states.len();
loop {
let current = &states[state as usize];
let child = current.base ^ (byte as u32);
if (child as usize) < states_len && states[child as usize].check == state {
return child;
}
if state == ROOT_IDX {
return ROOT_IDX;
}
let fail = current.fail;
if fail == DEAD_STATE {
return ROOT_IDX;
}
state = fail;
}
}
#[inline]
pub fn outputs(&self, state: u32) -> iter::OutputIter<'_> {
let outpos = self.states[state as usize].outpos;
iter::OutputIter::new(&self.outputs, outpos)
}
#[inline]
pub fn consume(&self, state: u32, byte: u8) -> (u32, iter::OutputIter<'_>) {
let next = self.next_state(state, byte);
let outpos = self.states[next as usize].outpos;
(next, iter::OutputIter::new(&self.outputs, outpos))
}
pub fn serialize(&self) -> Vec<u8> {
use core::mem::size_of;
let state_bytes = self.states.len() * size_of::<State>();
let output_bytes = self.outputs.len() * size_of::<Output>();
let total = 4 + state_bytes + 4 + output_bytes + 1 + 4;
let mut buf = Vec::with_capacity(total);
buf.extend_from_slice(&(self.states.len() as u32).to_le_bytes());
for state in &self.states {
buf.extend_from_slice(&state.base.to_le_bytes());
buf.extend_from_slice(&state.check.to_le_bytes());
buf.extend_from_slice(&state.fail.to_le_bytes());
buf.extend_from_slice(&state.outpos.to_le_bytes());
}
buf.extend_from_slice(&(self.outputs.len() as u32).to_le_bytes());
for output in &self.outputs {
buf.extend_from_slice(&output.pattern_id.to_le_bytes());
buf.extend_from_slice(&output.length.to_le_bytes());
buf.extend_from_slice(&output.parent.to_le_bytes());
}
buf.push(match self.match_kind {
MatchKind::Overlapping => 0,
MatchKind::LeftmostFirst => 1,
MatchKind::LeftmostLongest => 2,
MatchKind::WordPiece => 3,
});
buf.extend_from_slice(&self.anchor.unwrap_or(u32::MAX).to_le_bytes());
buf
}
pub fn deserialize(data: &[u8]) -> Option<(Self, &[u8])> {
use core::mem::size_of;
if data.len() < 4 {
return None;
}
let mut pos = 0;
let num_states = u32::from_le_bytes(data[pos..pos + 4].try_into().ok()?) as usize;
pos += 4;
let state_bytes = num_states * size_of::<State>();
if data.len() < pos + state_bytes {
return None;
}
let mut states = Vec::with_capacity(num_states);
for i in 0..num_states {
let start = pos + i * size_of::<State>();
states.push(State {
base: u32::from_le_bytes(data[start..start + 4].try_into().ok()?),
check: u32::from_le_bytes(data[start + 4..start + 8].try_into().ok()?),
fail: u32::from_le_bytes(data[start + 8..start + 12].try_into().ok()?),
outpos: u32::from_le_bytes(data[start + 12..start + 16].try_into().ok()?),
});
}
pos += state_bytes;
if data.len() < pos + 4 {
return None;
}
let num_outputs = u32::from_le_bytes(data[pos..pos + 4].try_into().ok()?) as usize;
pos += 4;
let output_bytes = num_outputs * size_of::<Output>();
if data.len() < pos + output_bytes {
return None;
}
let mut outputs = Vec::with_capacity(num_outputs);
for i in 0..num_outputs {
let start = pos + i * size_of::<Output>();
outputs.push(Output {
pattern_id: u32::from_le_bytes(data[start..start + 4].try_into().ok()?),
length: u32::from_le_bytes(data[start + 4..start + 8].try_into().ok()?),
parent: u32::from_le_bytes(data[start + 8..start + 12].try_into().ok()?),
});
}
pos += output_bytes;
if data.len() < pos + 1 {
return None;
}
let match_kind = match data[pos] {
0 => MatchKind::Overlapping,
1 => MatchKind::LeftmostFirst,
2 => MatchKind::LeftmostLongest,
3 => MatchKind::WordPiece,
_ => return None,
};
pos += 1;
if data.len() < pos + 4 {
return None;
}
let anchor_val = u32::from_le_bytes(data[pos..pos + 4].try_into().ok()?);
let anchor = if anchor_val == u32::MAX { None } else { Some(anchor_val) };
pos += 4;
Some((
Self { states, outputs, match_kind, anchor },
&data[pos..],
))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_creates_dead_and_root_states() {
        let daac = DoubleArrayAhoCorasick::new();
        assert_eq!(daac.states.len(), 2);
        assert_eq!(daac.states[ROOT_IDX as usize].check, u32::MAX);
        assert_eq!(daac.states[ROOT_IDX as usize].fail, ROOT_IDX);
    }

    #[test]
    fn test_is_vacant() {
        let daac = DoubleArrayAhoCorasick::new();
        assert!(daac.is_vacant(100));
        assert!(!daac.is_vacant(ROOT_IDX));
    }

    #[test]
    fn test_from_trie_single_pattern() {
        let mut trie = Trie::new();
        trie.add(b"he", 0);
        trie.build(MatchKind::Overlapping);
        let daac = DoubleArrayAhoCorasick::from_trie(trie);
        assert_eq!(daac.outputs.len(), 1);
        assert_eq!(daac.outputs[0].pattern_id, 0);
        assert_eq!(daac.outputs[0].length, 2);
        let matches = daac.find(b"she");
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pattern_id, 0);
    }

    #[test]
    fn test_from_trie_multiple_patterns() {
        let mut trie = Trie::new();
        trie.add(b"he", 0);
        trie.add(b"she", 1);
        trie.add(b"hers", 2);
        trie.build(MatchKind::Overlapping);
        let daac = DoubleArrayAhoCorasick::from_trie(trie);
        assert_eq!(daac.outputs.len(), 3);
        let matches = daac.find(b"ushers");
        assert_eq!(matches.len(), 3);
        let match_tuples: Vec<(u32, usize, usize)> = matches
            .iter()
            .map(|m| (m.pattern_id, m.start, m.end))
            .collect();
        assert!(match_tuples.contains(&(1, 1, 4)));
        assert!(match_tuples.contains(&(0, 2, 4)));
        assert!(match_tuples.contains(&(2, 2, 6)));
    }

    #[test]
    fn test_no_match() {
        let mut trie = Trie::new();
        trie.add(b"he", 0);
        trie.build(MatchKind::Overlapping);
        let daac = DoubleArrayAhoCorasick::from_trie(trie);
        let matches = daac.find(b"xyz");
        assert_eq!(matches.len(), 0);
    }

    #[test]
    fn test_fail_links_work() {
        let mut trie = Trie::new();
        trie.add(b"he", 0);
        trie.add(b"she", 1);
        trie.build(MatchKind::Overlapping);
        let daac = DoubleArrayAhoCorasick::from_trie(trie);
        let matches = daac.find(b"she");
        assert_eq!(matches.len(), 2);
        let match_ids: Vec<u32> = matches.iter().map(|m| m.pattern_id).collect();
        assert!(match_ids.contains(&0));
        assert!(match_ids.contains(&1));
    }

    #[test]
    fn test_empty_trie() {
        let trie = Trie::new();
        let daac = DoubleArrayAhoCorasick::from_trie(trie);
        assert!(daac.states.len() >= 2);
        assert_eq!(daac.outputs.len(), 0);
    }

    #[test]
    fn test_single_byte_patterns() {
        let mut trie = Trie::new();
        trie.add(b"a", 0);
        trie.add(b"b", 1);
        trie.add(b"c", 2);
        trie.build(MatchKind::Overlapping);
        let daac = DoubleArrayAhoCorasick::from_trie(trie);
        let matches = daac.find(b"abc");
        assert_eq!(matches.len(), 3);
        let match_ids: Vec<u32> = matches.iter().map(|m| m.pattern_id).collect();
        assert!(match_ids.contains(&0));
        assert!(match_ids.contains(&1));
        assert!(match_ids.contains(&2));
    }

    #[test]
    fn test_start_state() {
        let daac = DoubleArrayAhoCorasick::new();
        assert_eq!(daac.start_state(), ROOT_IDX);
    }

    #[test]
    fn test_next_state_and_outputs() {
        let mut trie = Trie::new();
        trie.add(b"he", 0);
        trie.add(b"she", 1);
        trie.build(MatchKind::Overlapping);
        let daac = DoubleArrayAhoCorasick::from_trie(trie);
        let mut state = daac.start_state();
        state = daac.next_state(state, b's');
        let outputs: Vec<u32> = daac.outputs(state).map(|o| o.pattern_id).collect();
        assert!(outputs.is_empty());
        state = daac.next_state(state, b'h');
        let outputs: Vec<u32> = daac.outputs(state).map(|o| o.pattern_id).collect();
        assert!(outputs.is_empty());
        state = daac.next_state(state, b'e');
        let outputs: Vec<u32> = daac.outputs(state).map(|o| o.pattern_id).collect();
        assert_eq!(outputs.len(), 2);
        assert!(outputs.contains(&0));
        assert!(outputs.contains(&1));
    }

    #[test]
    fn test_consume_agrees_with_next_state() {
        let mut trie = Trie::new();
        trie.add(b"he", 0);
        trie.add(b"she", 1);
        trie.add(b"hers", 2);
        trie.build(MatchKind::Overlapping);
        let daac = DoubleArrayAhoCorasick::from_trie(trie);
        let mut state = daac.start_state();
        for &byte in b"ushers" {
            let expected_next = daac.next_state(state, byte);
            let expected: Vec<u32> = daac
                .outputs(expected_next)
                .map(|o| o.pattern_id)
                .collect();
            let (next, outputs) = daac.consume(state, byte);
            let got: Vec<u32> = outputs.map(|o| o.pattern_id).collect();
            assert_eq!(next, expected_next);
            assert_eq!(got, expected);
            state = next;
        }
    }

    #[test]
    fn test_outputs_empty_state() {
        let mut trie = Trie::new();
        trie.add(b"abc", 0);
        trie.build(MatchKind::Overlapping);
        let daac = DoubleArrayAhoCorasick::from_trie(trie);
        let state = daac.start_state();
        let outputs: Vec<_> = daac.outputs(state).collect();
        assert!(outputs.is_empty());
    }

    #[test]
    fn test_manual_traversal_matches_find() {
        let mut trie = Trie::new();
        trie.add(b"a", 0);
        trie.add(b"ab", 1);
        trie.add(b"abc", 2);
        trie.build(MatchKind::Overlapping);
        let daac = DoubleArrayAhoCorasick::from_trie(trie);
        let text = b"abc";
        let mut manual_matches = Vec::new();
        let mut state = daac.start_state();
        for (pos, &byte) in text.iter().enumerate() {
            state = daac.next_state(state, byte);
            for output in daac.outputs(state) {
                manual_matches.push((output.pattern_id, pos + 1 - output.length as usize, pos + 1));
            }
        }
        let find_matches: Vec<_> = daac
            .find(text)
            .iter()
            .map(|m| (m.pattern_id, m.start, m.end))
            .collect();
        assert_eq!(manual_matches.len(), find_matches.len());
        for m in &manual_matches {
            assert!(find_matches.contains(m));
        }
    }

    #[test]
    fn test_serialize_deserialize_roundtrip() {
        let mut trie = Trie::new();
        trie.add(b"he", 0);
        trie.add(b"she", 1);
        trie.add(b"hers", 2);
        trie.build(MatchKind::Overlapping);
        let daac = DoubleArrayAhoCorasick::from_trie(trie);
        let bytes = daac.serialize();
        let (restored, remainder) = DoubleArrayAhoCorasick::deserialize(&bytes).unwrap();
        assert!(remainder.is_empty());
        assert_eq!(daac.states.len(), restored.states.len());
        assert_eq!(daac.outputs.len(), restored.outputs.len());
        assert_eq!(daac.match_kind, restored.match_kind);
        let original_matches = daac.find(b"ushers");
        let restored_matches = restored.find(b"ushers");
        assert_eq!(original_matches, restored_matches);
    }

    #[test]
    fn test_serialize_deserialize_leftmost_first() {
        let mut trie = Trie::new();
        trie.add(b"a", 0);
        trie.add(b"aa", 1);
        trie.add(b"aaa", 2);
        trie.build(MatchKind::LeftmostFirst);
        let daac = DoubleArrayAhoCorasick::from_trie(trie);
        let bytes = daac.serialize();
        let (restored, _) = DoubleArrayAhoCorasick::deserialize(&bytes).unwrap();
        assert_eq!(restored.match_kind, MatchKind::LeftmostFirst);
        let original_matches = daac.find(b"aaaa");
        let restored_matches = restored.find(b"aaaa");
        assert_eq!(original_matches, restored_matches);
    }

    #[test]
    fn test_serialize_deserialize_empty() {
        let daac = DoubleArrayAhoCorasick::new();
        let bytes = daac.serialize();
        let (restored, _) = DoubleArrayAhoCorasick::deserialize(&bytes).unwrap();
        assert_eq!(daac.states.len(), restored.states.len());
        assert_eq!(daac.outputs.len(), restored.outputs.len());
    }

    #[test]
    fn test_deserialize_returns_trailing_bytes() {
        let mut trie = Trie::new();
        trie.add(b"he", 0);
        trie.build(MatchKind::Overlapping);
        let daac = DoubleArrayAhoCorasick::from_trie(trie);
        let mut bytes = daac.serialize();
        bytes.extend_from_slice(&[7, 8, 9]);
        let (restored, remainder) = DoubleArrayAhoCorasick::deserialize(&bytes).unwrap();
        assert_eq!(remainder, &[7u8, 8, 9][..]);
        assert_eq!(daac.find(b"she"), restored.find(b"she"));
    }

    #[test]
    fn test_deserialize_invalid_data() {
        assert!(DoubleArrayAhoCorasick::deserialize(&[0, 1]).is_none());
        let mut trie = Trie::new();
        trie.add(b"a", 0);
        trie.build(MatchKind::Overlapping);
        let daac = DoubleArrayAhoCorasick::from_trie(trie);
        let mut bytes = daac.serialize();
        // The match-kind byte sits just before the trailing 4-byte anchor.
        let match_kind_idx = bytes.len() - 5;
        bytes[match_kind_idx] = 99;
        assert!(DoubleArrayAhoCorasick::deserialize(&bytes).is_none());
    }

    #[test]
    fn test_deserialize_truncated_input() {
        let mut trie = Trie::new();
        trie.add(b"a", 0);
        trie.build(MatchKind::Overlapping);
        let bytes = DoubleArrayAhoCorasick::from_trie(trie).serialize();
        for cut in [0, 3, 4, 20, bytes.len() - 5, bytes.len() - 1] {
            assert!(
                DoubleArrayAhoCorasick::deserialize(&bytes[..cut]).is_none(),
                "prefix of length {cut} should be rejected"
            );
        }
    }

    #[test]
    fn test_deserialize_rejects_oversized_counts() {
        assert!(DoubleArrayAhoCorasick::deserialize(&u32::MAX.to_le_bytes()).is_none());
    }
}