mod builder;
pub mod iter;
mod mapper;
use core::mem;
use core::num::NonZeroU32;
use alloc::vec::Vec;
pub use crate::charwise::builder::CharwiseDoubleArrayAhoCorasickBuilder;
use crate::charwise::iter::{
CharWithEndOffsetIterator, FindIterator, FindOverlappingIterator,
FindOverlappingNoSuffixIterator, FindOverlappingStepper, FindStepper, LeftmostFindIterator,
StrIterator,
};
use crate::charwise::mapper::CodeMapper;
use crate::errors::{DaachorseError, Result};
use crate::serializer::{Serializable, SerializableVec};
use crate::utils::FromU32;
use crate::{MatchKind, Output};
/// Index of the root state. Every iterator and stepper created in this module
/// starts from this state, and the transition functions fall back to it.
const ROOT_STATE_IDX: u32 = 0;
/// Index used as the default `check` and `fail` value of a `State`. The
/// leftmost transition function treats a failure link equal to this index as
/// "stop and return to the root".
const DEAD_STATE_IDX: u32 = 1;
/// Character-wise Aho-Corasick automaton stored as an array of `State`s.
///
/// `V` is the type of the value carried by each entry of `outputs`.
#[derive(Clone, Eq, Hash, PartialEq)]
pub struct CharwiseDoubleArrayAhoCorasick<V> {
/// State table. Transitions, failure links and `check` values all refer to
/// indices of this vector.
states: Vec<State>,
/// Maps an input `char` to the code that is XOR-ed with a state's `base`
/// when following a transition.
mapper: CodeMapper,
/// Output records. `State::output_pos` and `Output::parent` refer to entries
/// of this vector using one-based positions (see the checks in `deserialize`).
outputs: Vec<Output<V>>,
/// Match semantics; asserted by every `find_*` constructor.
match_kind: MatchKind,
/// Value reported by `num_states()`; distinct from `states.len()`, which is
/// reported by `num_elements()`.
num_states: u32,
}
impl<V> CharwiseDoubleArrayAhoCorasick<V> {
/// Creates an automaton from `patterns` using a builder with default settings.
///
/// This is a shorthand for `CharwiseDoubleArrayAhoCorasickBuilder::new().build(patterns)`.
/// NOTE(review): the `V: TryFrom<usize>` bound suggests each value is derived
/// from the pattern's position in `patterns` — confirm against the builder.
///
/// # Errors
///
/// Returns whatever error the builder's `build` reports.
pub fn new<I, P>(patterns: I) -> Result<Self>
where
    I: IntoIterator<Item = P>,
    P: AsRef<str>,
    V: Copy + TryFrom<usize>,
{
    let builder = CharwiseDoubleArrayAhoCorasickBuilder::new();
    builder.build(patterns)
}
/// Creates an automaton from `(pattern, value)` pairs using a builder with
/// default settings.
///
/// This is a shorthand for
/// `CharwiseDoubleArrayAhoCorasickBuilder::new().build_with_values(patvals)`.
///
/// # Errors
///
/// Returns whatever error the builder's `build_with_values` reports.
pub fn with_values<I, P>(patvals: I) -> Result<Self>
where
    I: IntoIterator<Item = (P, V)>,
    P: AsRef<str>,
    V: Copy,
{
    let builder = CharwiseDoubleArrayAhoCorasickBuilder::new();
    builder.build_with_values(patvals)
}
/// Returns a [`FindIterator`] that scans `haystack` with this automaton.
///
/// # Panics
///
/// Panics if the automaton's match kind is not standard.
pub fn find_iter<P>(&self, haystack: P) -> FindIterator<'_, StrIterator<P>, V>
where
    P: AsRef<str>,
{
    assert!(
        self.match_kind.is_standard(),
        "Error: match_kind must be standard."
    );
    // SAFETY: NOTE(review) — the contract of `CharWithEndOffsetIterator::new`
    // is not visible here; presumably it requires a valid UTF-8 byte stream,
    // which holds because the bytes come from a `str`. Confirm.
    let haystack = unsafe { CharWithEndOffsetIterator::new(StrIterator::new(haystack)) };
    FindIterator {
        pma: self,
        haystack,
    }
}
/// Returns a [`FindIterator`] that scans the bytes yielded by `haystack`.
///
/// # Panics
///
/// Panics if the automaton's match kind is not standard.
///
/// # Safety
///
/// `haystack` is handed directly to the unsafe constructor
/// `CharWithEndOffsetIterator::new`, so its contract applies to the caller.
/// NOTE(review): presumably the bytes must form valid UTF-8 — confirm.
pub unsafe fn find_iter_from_iter<P>(&self, haystack: P) -> FindIterator<'_, P, V>
where
    P: Iterator<Item = u8>,
{
    assert!(
        self.match_kind.is_standard(),
        "Error: match_kind must be standard."
    );
    let haystack = CharWithEndOffsetIterator::new(haystack);
    FindIterator {
        pma: self,
        haystack,
    }
}
/// Returns a [`FindOverlappingIterator`] over `haystack`, positioned at the
/// root state with no pending output.
///
/// # Panics
///
/// Panics if the automaton's match kind is not standard.
pub fn find_overlapping_iter<P>(
    &self,
    haystack: P,
) -> FindOverlappingIterator<'_, StrIterator<P>, V>
where
    P: AsRef<str>,
{
    assert!(
        self.match_kind.is_standard(),
        "Error: match_kind must be standard."
    );
    // SAFETY: NOTE(review) — the contract of `CharWithEndOffsetIterator::new`
    // is not visible here; presumably it requires a valid UTF-8 byte stream,
    // which holds because the bytes come from a `str`. Confirm.
    let haystack = unsafe { CharWithEndOffsetIterator::new(StrIterator::new(haystack)) };
    FindOverlappingIterator {
        pma: self,
        haystack,
        state_id: ROOT_STATE_IDX,
        pos: 0,
        output_pos: None,
    }
}
/// Returns a [`FindOverlappingIterator`] over the bytes yielded by `haystack`,
/// positioned at the root state with no pending output.
///
/// # Panics
///
/// Panics if the automaton's match kind is not standard.
///
/// # Safety
///
/// `haystack` is handed directly to the unsafe constructor
/// `CharWithEndOffsetIterator::new`, so its contract applies to the caller.
/// NOTE(review): presumably the bytes must form valid UTF-8 — confirm.
pub unsafe fn find_overlapping_iter_from_iter<P>(
    &self,
    haystack: P,
) -> FindOverlappingIterator<'_, P, V>
where
    P: Iterator<Item = u8>,
{
    assert!(
        self.match_kind.is_standard(),
        "Error: match_kind must be standard."
    );
    let haystack = CharWithEndOffsetIterator::new(haystack);
    FindOverlappingIterator {
        pma: self,
        haystack,
        state_id: ROOT_STATE_IDX,
        pos: 0,
        output_pos: None,
    }
}
/// Returns a [`FindOverlappingNoSuffixIterator`] over `haystack`, positioned
/// at the root state.
///
/// # Panics
///
/// Panics if the automaton's match kind is not standard.
pub fn find_overlapping_no_suffix_iter<P>(
    &self,
    haystack: P,
) -> FindOverlappingNoSuffixIterator<'_, StrIterator<P>, V>
where
    P: AsRef<str>,
{
    assert!(
        self.match_kind.is_standard(),
        "Error: match_kind must be standard."
    );
    // SAFETY: NOTE(review) — the contract of `CharWithEndOffsetIterator::new`
    // is not visible here; presumably it requires a valid UTF-8 byte stream,
    // which holds because the bytes come from a `str`. Confirm.
    let haystack = unsafe { CharWithEndOffsetIterator::new(StrIterator::new(haystack)) };
    FindOverlappingNoSuffixIterator {
        pma: self,
        haystack,
        state_id: ROOT_STATE_IDX,
    }
}
/// Returns a [`FindOverlappingNoSuffixIterator`] over the bytes yielded by
/// `haystack`, positioned at the root state.
///
/// # Panics
///
/// Panics if the automaton's match kind is not standard.
///
/// # Safety
///
/// `haystack` is handed directly to the unsafe constructor
/// `CharWithEndOffsetIterator::new`, so its contract applies to the caller.
/// NOTE(review): presumably the bytes must form valid UTF-8 — confirm.
pub unsafe fn find_overlapping_no_suffix_iter_from_iter<P>(
    &self,
    haystack: P,
) -> FindOverlappingNoSuffixIterator<'_, P, V>
where
    P: Iterator<Item = u8>,
{
    assert!(
        self.match_kind.is_standard(),
        "Error: match_kind must be standard."
    );
    let haystack = CharWithEndOffsetIterator::new(haystack);
    FindOverlappingNoSuffixIterator {
        pma: self,
        haystack,
        state_id: ROOT_STATE_IDX,
    }
}
/// Returns a [`LeftmostFindIterator`] over `haystack`, starting at position 0.
///
/// Unlike the other `find_*` constructors, the haystack is stored as given
/// rather than being wrapped in a byte iterator.
///
/// # Panics
///
/// Panics if the automaton's match kind is not one of the leftmost kinds.
pub fn leftmost_find_iter<P>(&self, haystack: P) -> LeftmostFindIterator<'_, P, V>
where
    P: AsRef<str>,
{
    let leftmost = self.match_kind.is_leftmost();
    assert!(leftmost, "Error: match_kind must be leftmost.");
    LeftmostFindIterator {
        pos: 0,
        haystack,
        pma: self,
    }
}
/// Returns a [`FindStepper`] positioned at the root state and offset 0.
///
/// # Panics
///
/// Panics if the automaton's match kind is not standard.
#[must_use]
pub fn find_stepper(&self) -> FindStepper<'_, V> {
    let standard = self.match_kind.is_standard();
    assert!(standard, "Error: match_kind must be standard.");
    FindStepper {
        pos: 0,
        state_id: ROOT_STATE_IDX,
        pma: self,
    }
}
/// Returns a [`FindOverlappingStepper`] positioned at the root state and
/// offset 0.
///
/// # Panics
///
/// Panics if the automaton's match kind is not standard.
#[must_use]
pub fn find_overlapping_stepper(&self) -> FindOverlappingStepper<'_, V> {
    let standard = self.match_kind.is_standard();
    assert!(standard, "Error: match_kind must be standard.");
    FindOverlappingStepper {
        pos: 0,
        state_id: ROOT_STATE_IDX,
        pma: self,
    }
}
/// Returns the [`MatchKind`] stored in this automaton.
#[must_use]
pub const fn match_kind(&self) -> MatchKind {
self.match_kind
}
/// Returns the stored `num_states` counter widened to `usize`.
///
/// This is not the length of the state array; see `num_elements` for that.
#[must_use]
pub fn num_states(&self) -> usize {
usize::from_u32(self.num_states)
}
/// Returns the length of the state array, i.e. `states.len()`.
#[must_use]
pub fn num_elements(&self) -> usize {
self.states.len()
}
/// Returns the number of heap bytes attributed to this automaton.
///
/// The figure is the sum of the state array, the code mapper (as reported by
/// `CodeMapper::heap_bytes`) and the output array. The two arrays are measured
/// by their length, not their capacity.
#[must_use]
pub fn heap_bytes(&self) -> usize {
    let states_bytes = self.states.len() * mem::size_of::<State>();
    let mapper_bytes = self.mapper.heap_bytes();
    let outputs_bytes = self.outputs.len() * mem::size_of::<Output<V>>();
    states_bytes + mapper_bytes + outputs_bytes
}
/// Serializes the automaton into a freshly allocated byte vector.
///
/// Fields are written in this order: `states`, `mapper`, `outputs`,
/// `match_kind`, `num_states`. `deserialize` and `deserialize_unchecked` read
/// them back in the same order.
#[must_use]
pub fn serialize(&self) -> Vec<u8>
where
    V: Serializable,
{
    // Reserve the exact size up front so the writes below never reallocate.
    let capacity = self.states.serialized_bytes()
        + self.mapper.serialized_bytes()
        + self.outputs.serialized_bytes()
        + MatchKind::serialized_bytes()
        + u32::serialized_bytes();
    let mut buf = Vec::with_capacity(capacity);

    self.states.serialize_to_vec(&mut buf);
    self.mapper.serialize_to_vec(&mut buf);
    self.outputs.serialize_to_vec(&mut buf);
    self.match_kind.serialize_to_vec(&mut buf);
    self.num_states.serialize_to_vec(&mut buf);

    buf
}
/// Deserializes an automaton from `source` and validates it.
///
/// Fields are read in the order written by `serialize`. On success the
/// automaton is returned together with the unread remainder of `source`.
///
/// The validation exists because the transition functions index `states`
/// without bounds checks. It guarantees that:
///
/// * every code in the mapper table is either `INVALID_CODE` or smaller than
///   the alphabet size;
/// * the state array is non-empty and its length is a multiple of the block
///   length (the alphabet size rounded up to a power of two, at least 2), so
///   `base ^ code` stays inside the array whenever `base` does;
/// * every `base` and `fail` is a valid state index;
/// * every `output_pos` (one-based) refers to an existing output;
/// * every output's `parent` (one-based) refers to an earlier output.
///
/// # Errors
///
/// Returns an error if any field fails to deserialize, or
/// `DaachorseError::invalid_automaton()` if any check above fails. That
/// includes an alphabet size so large that rounding it up to a power of two
/// does not fit in `u32`.
pub fn deserialize(source: &[u8]) -> Result<(Self, &[u8])>
where
    V: Serializable,
{
    let (states, source) = Vec::<State>::deserialize_from_slice(source)?;
    let (mapper, source) = CodeMapper::deserialize_from_slice(source)?;
    let (outputs, source) = Vec::<Output<V>>::deserialize_from_slice(source)?;
    let (match_kind, source) = MatchKind::deserialize_from_slice(source)?;
    let (num_states, source) = u32::deserialize_from_slice(source)?;
    let pma = Self {
        states,
        mapper,
        outputs,
        match_kind,
        num_states,
    };

    let alphabet_size = pma.mapper.alphabet_size();
    for &id in &pma.mapper.table {
        if id == crate::charwise::mapper::INVALID_CODE {
            continue;
        }
        if id >= alphabet_size {
            return Err(DaachorseError::invalid_automaton());
        }
    }

    // `alphabet_size` comes from untrusted bytes. A plain `next_power_of_two`
    // panics in debug builds and wraps to 0 in release builds when the result
    // does not fit in `u32`. That would silently shrink the block length to 2
    // and defeat the bounds reasoning below, so overflow is rejected here.
    let block_len = alphabet_size
        .checked_next_power_of_two()
        .map(|len| usize::from_u32(len.max(2)))
        .ok_or_else(DaachorseError::invalid_automaton)?;

    let states_len = pma.states.len();
    if states_len == 0 || states_len % block_len != 0 {
        return Err(DaachorseError::invalid_automaton());
    }

    let outputs_len = pma.outputs.len();
    for state in &pma.states {
        if let Some(base) = state.base() {
            if usize::from_u32(base.get()) >= states_len {
                return Err(DaachorseError::invalid_automaton());
            }
        }
        if usize::from_u32(state.fail()) >= states_len {
            return Err(DaachorseError::invalid_automaton());
        }
        if let Some(output_pos) = state.output_pos() {
            // Positions are one-based; `get()` is non-zero so this cannot underflow.
            if usize::from_u32(output_pos.get() - 1) >= outputs_len {
                return Err(DaachorseError::invalid_automaton());
            }
        }
    }

    for (i, output) in pma.outputs.iter().enumerate() {
        if let Some(parent) = output.parent {
            // A parent must precede its child, which also rules out cycles.
            if usize::from_u32(parent.get() - 1) >= i {
                return Err(DaachorseError::invalid_automaton());
            }
        }
    }

    Ok((pma, source))
}
/// Deserializes an automaton from `source` without any validation.
///
/// Fields are read in the order written by `serialize`. The automaton is
/// returned together with the unread remainder of `source`.
///
/// # Safety
///
/// Every field is extracted with `unwrap_unchecked`, so a deserialization
/// failure is undefined behavior. None of the consistency checks done by
/// `deserialize` are performed, and the transition functions later index the
/// state array without bounds checks. `source` must therefore hold bytes for
/// which `deserialize` would succeed, e.g. the unmodified output of
/// `serialize`.
#[must_use]
pub unsafe fn deserialize_unchecked(source: &[u8]) -> (Self, &[u8])
where
    V: Serializable,
{
    let (states, rest) = Vec::<State>::deserialize_from_slice(source).unwrap_unchecked();
    let (mapper, rest) = CodeMapper::deserialize_from_slice(rest).unwrap_unchecked();
    let (outputs, rest) = Vec::<Output<V>>::deserialize_from_slice(rest).unwrap_unchecked();
    let (match_kind, rest) = MatchKind::deserialize_from_slice(rest).unwrap_unchecked();
    let (num_states, rest) = u32::deserialize_from_slice(rest).unwrap_unchecked();

    let pma = Self {
        states,
        mapper,
        outputs,
        match_kind,
        num_states,
    };
    (pma, rest)
}
/// Follows the transition labelled `mapped_c` out of `state_id`.
///
/// The candidate child is `base ^ mapped_c`. It is a real child only when its
/// `check` field equals `state_id`. Returns `None` when the state has no
/// `base` or the candidate belongs to another parent.
///
/// # Safety
///
/// Both `state_id` and `base ^ mapped_c` must be valid indices into
/// `self.states`; neither access is bounds-checked.
// NOTE(review): no cast is visible in this body, so this `allow` may be stale.
#[allow(clippy::cast_possible_wrap)]
#[inline(always)]
unsafe fn child_index_unchecked(&self, state_id: u32, mapped_c: u32) -> Option<u32> {
    let parent = self.states.get_unchecked(usize::from_u32(state_id));
    let base = parent.base()?;
    let candidate = base.get() ^ mapped_c;
    let owner = self
        .states
        .get_unchecked(usize::from_u32(candidate))
        .check();
    if owner == state_id {
        Some(candidate)
    } else {
        None
    }
}
/// Returns the state reached from `state_id` after reading `c`.
///
/// If the mapper has no code for `c`, the root is returned. Otherwise failure
/// links are followed until a state with a transition on `c` is found. If even
/// the root has none, the root is returned.
///
/// # Safety
///
/// `state_id`, every state reachable through failure links, and every
/// candidate child computed along the way must be valid indices into
/// `self.states`.
#[inline(always)]
unsafe fn next_state_id_unchecked(&self, mut state_id: u32, c: char) -> u32 {
    let mapped_c = match self.mapper.get(c) {
        Some(code) => code,
        // Characters unknown to the mapper cannot extend any match.
        None => return ROOT_STATE_IDX,
    };
    loop {
        if let Some(child) = self.child_index_unchecked(state_id, mapped_c) {
            return child;
        }
        if state_id == ROOT_STATE_IDX {
            return ROOT_STATE_IDX;
        }
        state_id = self.states.get_unchecked(usize::from_u32(state_id)).fail();
    }
}
/// Leftmost-mode variant of the transition function.
///
/// Behaves like `next_state_id_unchecked`, except that a failure link equal
/// to `DEAD_STATE_IDX` ends the walk immediately and yields the root.
///
/// # Safety
///
/// `state_id`, every state reachable through failure links, and every
/// candidate child computed along the way must be valid indices into
/// `self.states`.
#[inline(always)]
unsafe fn next_state_id_leftmost_unchecked(&self, mut state_id: u32, c: char) -> u32 {
    let mapped_c = match self.mapper.get(c) {
        Some(code) => code,
        // Characters unknown to the mapper cannot extend any match.
        None => return ROOT_STATE_IDX,
    };
    loop {
        if let Some(child) = self.child_index_unchecked(state_id, mapped_c) {
            return child;
        }
        if state_id == ROOT_STATE_IDX {
            return ROOT_STATE_IDX;
        }
        match self.states.get_unchecked(usize::from_u32(state_id)).fail() {
            DEAD_STATE_IDX => return ROOT_STATE_IDX,
            fail_id => state_id = fail_id,
        }
    }
}
}
/// One element of the automaton's state array.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct State {
/// Base value XOR-ed with a mapped character code to locate a child;
/// `None` when the state has no base.
base: Option<NonZeroU32>,
/// Index of the state that owns this element as a child; compared against
/// the parent's index when following a transition.
check: u32,
/// Index of the state to fall back to when no transition matches.
fail: u32,
/// One-based position into the automaton's output array, if any.
output_pos: Option<NonZeroU32>,
}
/// The default state has no base and no output; both `check` and `fail` point
/// at `DEAD_STATE_IDX`.
impl Default for State {
    fn default() -> Self {
        let dead = DEAD_STATE_IDX;
        Self {
            base: None,
            output_pos: None,
            check: dead,
            fail: dead,
        }
    }
}
/// Plain accessors for the fields of [`State`]; each getter is paired with its
/// setter.
impl State {
    /// Returns the base value, or `None` if the state has none.
    #[inline(always)]
    pub const fn base(&self) -> Option<NonZeroU32> {
        self.base
    }

    /// Sets the base value to `x`.
    #[inline(always)]
    #[allow(dead_code)]
    pub fn set_base(&mut self, x: NonZeroU32) {
        self.base = Some(x);
    }

    /// Returns the `check` index.
    #[inline(always)]
    pub const fn check(&self) -> u32 {
        self.check
    }

    /// Sets the `check` index to `x`.
    #[inline(always)]
    #[allow(dead_code)]
    pub fn set_check(&mut self, x: u32) {
        self.check = x;
    }

    /// Returns the failure-link index.
    #[inline(always)]
    pub const fn fail(&self) -> u32 {
        self.fail
    }

    /// Sets the failure-link index to `x`.
    #[inline(always)]
    #[allow(dead_code)]
    pub fn set_fail(&mut self, x: u32) {
        self.fail = x;
    }

    /// Returns the output position, or `None` if the state has no output.
    #[inline(always)]
    pub const fn output_pos(&self) -> Option<NonZeroU32> {
        self.output_pos
    }

    /// Replaces the output position with `x`; `None` clears it.
    #[inline(always)]
    #[allow(dead_code)]
    pub fn set_output_pos(&mut self, x: Option<NonZeroU32>) {
        self.output_pos = x;
    }
}
/// Serialization of a [`State`]: the fields `base`, `check`, `fail` and
/// `output_pos` are written, and read back, in that order.
impl Serializable for State {
    /// Appends the four fields to `dst`.
    #[inline(always)]
    fn serialize_to_vec(&self, dst: &mut Vec<u8>) {
        // Exhaustive destructuring: adding a field without serializing it
        // becomes a compile error.
        let Self {
            base,
            check,
            fail,
            output_pos,
        } = self;
        base.serialize_to_vec(dst);
        check.serialize_to_vec(dst);
        fail.serialize_to_vec(dst);
        output_pos.serialize_to_vec(dst);
    }

    /// Reads the four fields from the front of `src` and returns the state
    /// together with the unread remainder.
    #[inline(always)]
    fn deserialize_from_slice(src: &[u8]) -> Result<(Self, &[u8])> {
        let (base, rest) = Option::<NonZeroU32>::deserialize_from_slice(src)?;
        let (check, rest) = u32::deserialize_from_slice(rest)?;
        let (fail, rest) = u32::deserialize_from_slice(rest)?;
        let (output_pos, rest) = Option::<NonZeroU32>::deserialize_from_slice(rest)?;
        let state = Self {
            base,
            check,
            fail,
            output_pos,
        };
        Ok((state, rest))
    }

    /// Number of bytes one serialized state occupies: two optional non-zero
    /// `u32` fields plus two plain `u32` fields.
    #[inline(always)]
    fn serialized_bytes() -> usize {
        2 * (Option::<NonZeroU32>::serialized_bytes() + u32::serialized_bytes())
    }
}
// Unit tests for the charwise automaton: array layout, state counts,
// insertion-order independence, block growth, and (de)serialization.
#[cfg(test)]
mod tests {
use super::*;
// Pins the `base`, `check` and `fail` values of the first 11 states built
// from the patterns "AA", "AC", "BC" and "C".
#[test]
fn test_double_array() {
let patterns = vec!["AA", "AC", "BC", "C"];
let pma = CharwiseDoubleArrayAhoCorasick::<u32>::new(patterns).unwrap();
let base_expected = vec![
NonZeroU32::new(4), None, None, None, NonZeroU32::new(8), None, NonZeroU32::new(3), None, None, None, None, ];
let check_expected = vec![
1, 1, 6, 1, 0, 0, 0, 1, 4, 4, 1, ];
let fail_expected = vec![
ROOT_STATE_IDX, DEAD_STATE_IDX, 5, DEAD_STATE_IDX, ROOT_STATE_IDX, ROOT_STATE_IDX, ROOT_STATE_IDX, DEAD_STATE_IDX, 4, 5, DEAD_STATE_IDX, ];
let pma_base: Vec<_> = pma.states[0..11].iter().map(|state| state.base()).collect();
let pma_check: Vec<_> = pma.states[0..11]
.iter()
.map(|state| state.check())
.collect();
let pma_fail: Vec<_> = pma.states[0..11].iter().map(|state| state.fail()).collect();
assert_eq!(base_expected, pma_base);
assert_eq!(check_expected, pma_check);
assert_eq!(fail_expected, pma_fail);
}
// `num_states()` reports 13 for these three patterns.
#[test]
fn test_num_states() {
let patterns = vec!["abba", "baaba", "ababa"];
let pma = CharwiseDoubleArrayAhoCorasick::<u32>::new(patterns).unwrap();
assert_eq!(13, pma.num_states());
}
// The same pattern/value pairs given in a different order must produce
// identical state and output arrays.
#[test]
fn test_input_order() {
let patvals_sorted = vec![("ababa", 0), ("abba", 1), ("baaba", 2)];
let patvals_unsorted = vec![("abba", 1), ("baaba", 2), ("ababa", 0)];
let pma_sorted = CharwiseDoubleArrayAhoCorasick::with_values(patvals_sorted).unwrap();
let pma_unsorted = CharwiseDoubleArrayAhoCorasick::with_values(patvals_unsorted).unwrap();
assert_eq!(pma_sorted.states, pma_unsorted.states);
assert_eq!(pma_sorted.outputs, pma_unsorted.outputs);
}
// 126 single-character patterns: 127 states stored in a 128-element array.
#[test]
fn test_n_blocks_1_1() {
let mut patterns = vec![];
for i in '\u{0}'..='\u{7d}' {
let pattern: alloc::string::String = core::iter::once(i).collect();
patterns.push(pattern);
}
let pma = CharwiseDoubleArrayAhoCorasick::<u32>::new(patterns).unwrap();
assert_eq!(127, pma.num_states());
assert_eq!(128, pma.states.len());
assert_eq!(0x7e, pma.states[0].base().unwrap().get());
}
// One more pattern than above: 128 states, and the array grows to 256
// elements.
#[test]
fn test_n_blocks_1_2() {
let mut patterns = vec![];
for i in '\u{0}'..='\u{7e}' {
let pattern: alloc::string::String = core::iter::once(i).collect();
patterns.push(pattern);
}
let pma = CharwiseDoubleArrayAhoCorasick::<u32>::new(patterns).unwrap();
assert_eq!(128, pma.num_states());
assert_eq!(256, pma.states.len());
assert_eq!(0x80, pma.states[0].base().unwrap().get());
}
// Two levels of children: 255 states stored in a 256-element array.
#[test]
fn test_n_blocks_2_1() {
let mut patterns = vec![];
for i in '\u{0}'..='\u{7f}' {
let pattern: alloc::string::String = core::iter::once(i).collect();
patterns.push(pattern);
}
for i in '\u{0}'..='\u{7d}' {
let pattern = ['\u{0}', i].into_iter().collect();
patterns.push(pattern);
}
let pma = CharwiseDoubleArrayAhoCorasick::<u32>::new(patterns).unwrap();
assert_eq!(255, pma.num_states());
assert_eq!(256, pma.states.len());
assert_eq!(0x80, pma.states[0].base().unwrap().get());
assert_eq!(0x7e, pma.states[0x80].base().unwrap().get());
}
// One more second-level pattern than above: 256 states, and the array grows
// to 384 elements.
#[test]
fn test_n_blocks_2_2() {
let mut patterns = vec![];
for i in '\u{0}'..='\u{7f}' {
let pattern: alloc::string::String = core::iter::once(i).collect();
patterns.push(pattern);
}
for i in '\u{0}'..='\u{7e}' {
let pattern = ['\u{0}', i].into_iter().collect();
patterns.push(pattern);
}
let pma = CharwiseDoubleArrayAhoCorasick::<u32>::new(patterns).unwrap();
assert_eq!(256, pma.num_states());
assert_eq!(384, pma.states.len());
assert_eq!(0x80, pma.states[0].base().unwrap().get());
assert_eq!(0x100, pma.states[0x80].base().unwrap().get());
}
// A single `State` survives a serialize/deserialize round trip and occupies
// exactly `State::serialized_bytes()` bytes.
#[test]
fn test_serialize_state() {
let x = State {
base: NonZeroU32::new(42),
check: 57,
fail: 13,
output_pos: NonZeroU32::new(100),
};
let mut data = vec![];
x.serialize_to_vec(&mut data);
assert_eq!(data.len(), State::serialized_bytes());
let (y, rest) = State::deserialize_from_slice(&data).unwrap();
assert!(rest.is_empty());
assert_eq!(x, y);
}
// A whole automaton survives `serialize` followed by
// `deserialize_unchecked`, field by field, with no bytes left over.
#[test]
fn test_serialize_pma() {
let patterns = vec!["全世界", "世界", "に"];
let pma = CharwiseDoubleArrayAhoCorasick::<u32>::new(patterns).unwrap();
let bytes = pma.serialize();
let (other, rest) =
unsafe { CharwiseDoubleArrayAhoCorasick::deserialize_unchecked(&bytes) };
assert!(rest.is_empty());
assert_eq!(pma.states, other.states);
assert_eq!(pma.mapper, other.mapper);
assert_eq!(pma.outputs, other.outputs);
assert_eq!(pma.match_kind, other.match_kind);
assert_eq!(pma.num_states, other.num_states);
}
// The checked `deserialize` must reject an all-zero byte buffer.
#[test]
fn test_deserialize_invalid_pma() {
let bytes = [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ];
let pma = CharwiseDoubleArrayAhoCorasick::<u32>::deserialize(&bytes);
assert!(pma.is_err());
}
}