use std::cmp;
use std::collections::{BTreeSet, VecDeque};
use std::fmt;
use std::mem::size_of;
use std::ops::{Index, IndexMut};
use ahocorasick::MatchKind;
use automaton::Automaton;
use classes::{ByteClassBuilder, ByteClasses};
use error::Result;
use prefilter::{self, opposite_ascii_case, Prefilter, PrefilterObj};
use state_id::{dead_id, fail_id, usize_to_state_id, StateID};
use Match;
/// Identifier of a pattern: its position in the sequence handed to the
/// builder (assigned via `enumerate` while building the trie).
pub type PatternID = usize;
/// Length in bytes of a pattern, stored next to its id in every match.
pub type PatternLength = usize;
#[derive(Clone)]
pub struct NFA<S> {
match_kind: MatchKind,
start_id: S,
max_pattern_len: usize,
pattern_count: usize,
heap_bytes: usize,
prefilter: Option<PrefilterObj>,
anchored: bool,
byte_classes: ByteClasses,
states: Vec<State<S>>,
}
impl<S: StateID> NFA<S> {
pub fn byte_classes(&self) -> &ByteClasses {
&self.byte_classes
}
pub fn prefilter_obj(&self) -> Option<&PrefilterObj> {
self.prefilter.as_ref()
}
pub fn heap_bytes(&self) -> usize {
self.heap_bytes
+ self.prefilter.as_ref().map_or(0, |p| p.as_ref().heap_bytes())
}
pub fn max_pattern_len(&self) -> usize {
self.max_pattern_len
}
pub fn pattern_count(&self) -> usize {
self.pattern_count
}
pub fn state_len(&self) -> usize {
self.states.len()
}
pub fn matches(&self, id: S) -> &[(PatternID, PatternLength)] {
&self.states[id.to_usize()].matches
}
pub fn iter_all_transitions<F: FnMut(u8, S)>(
&self,
byte_classes: &ByteClasses,
id: S,
f: F,
) {
self.states[id.to_usize()].trans.iter_all(byte_classes, f);
}
pub fn failure_transition(&self, id: S) -> S {
self.states[id.to_usize()].fail
}
pub fn next_state(&self, current: S, input: u8) -> S {
self.states[current.to_usize()].next_state(input)
}
fn state(&self, id: S) -> &State<S> {
&self.states[id.to_usize()]
}
fn state_mut(&mut self, id: S) -> &mut State<S> {
&mut self.states[id.to_usize()]
}
fn start(&self) -> &State<S> {
self.state(self.start_id)
}
fn start_mut(&mut self) -> &mut State<S> {
let id = self.start_id;
self.state_mut(id)
}
fn iter_transitions_mut(&mut self, id: S) -> IterTransitionsMut<S> {
IterTransitionsMut::new(self, id)
}
fn copy_matches(&mut self, src: S, dst: S) {
let (src, dst) =
get_two_mut(&mut self.states, src.to_usize(), dst.to_usize());
dst.matches.extend_from_slice(&src.matches);
}
fn copy_empty_matches(&mut self, dst: S) {
let start_id = self.start_id;
self.copy_matches(start_id, dst);
}
fn add_dense_state(&mut self, depth: usize) -> Result<S> {
let trans = Transitions::Dense(Dense::new());
let id = usize_to_state_id(self.states.len())?;
self.states.push(State {
trans,
fail: if self.anchored { dead_id() } else { self.start_id },
depth: depth,
matches: vec![],
});
Ok(id)
}
fn add_sparse_state(&mut self, depth: usize) -> Result<S> {
let trans = Transitions::Sparse(vec![]);
let id = usize_to_state_id(self.states.len())?;
self.states.push(State {
trans,
fail: if self.anchored { dead_id() } else { self.start_id },
depth: depth,
matches: vec![],
});
Ok(id)
}
}
impl<S: StateID> Automaton for NFA<S> {
    type ID = S;

    fn match_kind(&self) -> &MatchKind {
        &self.match_kind
    }

    fn anchored(&self) -> bool {
        self.anchored
    }

    fn prefilter(&self) -> Option<&dyn Prefilter> {
        self.prefilter.as_ref().map(|p| p.as_ref())
    }

    fn start_state(&self) -> S {
        self.start_id
    }

    /// True when `id` indexes an existing state.
    fn is_valid(&self, id: S) -> bool {
        self.states.get(id.to_usize()).is_some()
    }

    fn is_match_state(&self, id: S) -> bool {
        self.state(id).is_match()
    }

    /// Returns the `match_index`-th match of state `id` as a `Match` ending
    /// at `end`. Returns `None` if the state or the match does not exist.
    fn get_match(
        &self,
        id: S,
        match_index: usize,
        end: usize,
    ) -> Option<Match> {
        let state = self.states.get(id.to_usize())?;
        let &(pattern, len) = state.matches.get(match_index)?;
        Some(Match { pattern, len, end })
    }

    fn match_count(&self, id: S) -> usize {
        self.state(id).matches.len()
    }

    /// Follows failure transitions until a state has a real transition on
    /// `input`. This relies on the failure chain reaching a state that
    /// defines every byte (the start state's self-loops or the dead state's
    /// loops, set up during compilation).
    fn next_state(&self, mut current: S, input: u8) -> S {
        loop {
            let state = self.state(current);
            let next = state.next_state(input);
            if next == fail_id() {
                current = state.fail;
            } else {
                return next;
            }
        }
    }
}
/// One NFA state: its transitions, failure transition, recorded matches and
/// its depth (distance from the start state in the trie).
#[derive(Clone, Debug)]
pub struct State<S> {
    trans: Transitions<S>,
    fail: S,
    matches: Vec<(PatternID, PatternLength)>,
    depth: usize,
}

impl<S: StateID> State<S> {
    /// Approximate heap usage (length-based; spare capacity is ignored).
    fn heap_bytes(&self) -> usize {
        let match_bytes =
            self.matches.len() * size_of::<(PatternID, PatternLength)>();
        match_bytes + self.trans.heap_bytes()
    }

    /// Records that pattern `i`, of length `len`, matches in this state.
    fn add_match(&mut self, i: PatternID, len: PatternLength) {
        let entry = (i, len);
        self.matches.push(entry);
    }

    /// True when this state reports at least one match.
    fn is_match(&self) -> bool {
        !self.matches.is_empty()
    }

    /// Length of the first recorded match, if any.
    fn get_longest_match_len(&self) -> Option<usize> {
        self.matches.first().map(|&(_, len)| len)
    }

    /// Raw transition on `input`; `fail_id()` if none.
    fn next_state(&self, input: u8) -> S {
        self.trans.next_state(input)
    }

    /// Sets (or replaces) the transition on `input`.
    fn set_next_state(&mut self, input: u8, next: S) {
        self.trans.set_next_state(input, next);
    }
}
/// Dense transition table with one slot per byte value.
#[derive(Clone, Debug)]
struct Dense<S>(Vec<S>);

impl<S> Dense<S>
where
    S: StateID,
{
    /// Creates a table of 256 slots, all set to `fail_id()`.
    fn new() -> Self {
        let unset: S = fail_id();
        Dense(vec![unset; 256])
    }

    /// Number of slots (256 for tables created by `new`).
    #[inline]
    fn len(&self) -> usize {
        let Dense(ref slots) = *self;
        slots.len()
    }
}

impl<S> Index<u8> for Dense<S> {
    type Output = S;

    #[inline]
    fn index(&self, i: u8) -> &S {
        &self.0[usize::from(i)]
    }
}

impl<S> IndexMut<u8> for Dense<S> {
    #[inline]
    fn index_mut(&mut self, i: u8) -> &mut S {
        &mut self.0[usize::from(i)]
    }
}
/// Outgoing transitions of a single state.
///
/// `Sparse` holds `(byte, next)` pairs kept sorted by byte (`set_next_state`
/// inserts at the binary-search position); missing bytes mean `fail_id()`.
/// `Dense` has a slot for every byte value.
#[derive(Clone, Debug)]
enum Transitions<S> {
    Sparse(Vec<(u8, S)>),
    Dense(Dense<S>),
}

impl<S: StateID> Transitions<S> {
    /// Approximate heap usage of the transition storage (length-based).
    fn heap_bytes(&self) -> usize {
        match *self {
            Transitions::Dense(ref dense) => dense.len() * size_of::<S>(),
            Transitions::Sparse(ref sparse) => {
                sparse.len() * size_of::<(u8, S)>()
            }
        }
    }

    /// Target of the transition on `input`, or `fail_id()` if there is none.
    fn next_state(&self, input: u8) -> S {
        match *self {
            Transitions::Dense(ref dense) => dense[input],
            Transitions::Sparse(ref sparse) => sparse
                .iter()
                .find(|&&(b, _)| b == input)
                .map(|&(_, id)| id)
                .unwrap_or_else(|| fail_id()),
        }
    }

    /// Sets the transition on `input` to `next`, replacing any existing one.
    fn set_next_state(&mut self, input: u8, next: S) {
        match *self {
            Transitions::Dense(ref mut dense) => dense[input] = next,
            Transitions::Sparse(ref mut sparse) => {
                let pos = sparse.binary_search_by_key(&input, |&(b, _)| b);
                match pos {
                    Ok(i) => sparse[i].1 = next,
                    Err(i) => sparse.insert(i, (input, next)),
                }
            }
        }
    }

    /// Calls `f` for every transition that is not `fail_id()`, in byte order.
    fn iter<F: FnMut(u8, S)>(&self, mut f: F) {
        match *self {
            Transitions::Sparse(ref sparse) => {
                for &(byte, target) in sparse.iter() {
                    f(byte, target);
                }
            }
            Transitions::Dense(ref dense) => {
                for byte in AllBytesIter::new() {
                    let target = dense[byte];
                    if target == fail_id() {
                        continue;
                    }
                    f(byte, target);
                }
            }
        }
    }

    /// Calls `f` for all 256 bytes (including `fail_id()` targets) when
    /// `classes` is all singletons. Otherwise it calls `f` once per
    /// equivalence class, using that class's representative byte.
    fn iter_all<F: FnMut(u8, S)>(&self, classes: &ByteClasses, mut f: F) {
        let singleton = classes.is_singleton();
        match *self {
            Transitions::Sparse(ref sparse) if singleton => {
                sparse_iter(sparse, f);
            }
            Transitions::Sparse(ref sparse) => {
                // Classes cover contiguous byte ranges, so reporting only
                // when the class changes yields one byte per class.
                let mut last_class = None;
                sparse_iter(sparse, |b, next| {
                    let class = classes.get(b);
                    if last_class != Some(class) {
                        last_class = Some(class);
                        f(b, next);
                    }
                });
            }
            Transitions::Dense(ref dense) if singleton => {
                for b in AllBytesIter::new() {
                    f(b, dense[b]);
                }
            }
            Transitions::Dense(ref dense) => {
                for b in classes.representatives() {
                    f(b, dense[b]);
                }
            }
        }
    }
}
/// Walks the non-fail transitions of one state. It holds a mutable borrow of
/// the whole NFA, so callers can modify other states between steps through
/// `nfa()`. Each step re-reads the state's transitions.
#[derive(Debug)]
struct IterTransitionsMut<'a, S: StateID + 'a> {
    nfa: &'a mut NFA<S>,
    state_id: S,
    cur: usize,
}

impl<'a, S: StateID> IterTransitionsMut<'a, S> {
    fn new(nfa: &'a mut NFA<S>, state_id: S) -> IterTransitionsMut<'a, S> {
        IterTransitionsMut { nfa: nfa, state_id: state_id, cur: 0 }
    }

    /// Mutable access to the NFA while iteration is paused.
    fn nfa(&mut self) -> &mut NFA<S> {
        &mut *self.nfa
    }
}

impl<'a, S: StateID> Iterator for IterTransitionsMut<'a, S> {
    type Item = (u8, S);

    fn next(&mut self) -> Option<(u8, S)> {
        let index = self.state_id.to_usize();
        match self.nfa.states[index].trans {
            Transitions::Sparse(ref sparse) => {
                let item = sparse.get(self.cur).cloned();
                if item.is_some() {
                    self.cur += 1;
                }
                item
            }
            Transitions::Dense(ref dense) => {
                while self.cur < dense.len() {
                    debug_assert!(self.cur < 256);
                    let byte = self.cur as u8;
                    self.cur += 1;
                    let target = dense[byte];
                    if target != fail_id() {
                        return Some((byte, target));
                    }
                }
                None
            }
        }
    }
}
#[derive(Clone, Debug)]
pub struct Builder {
dense_depth: usize,
match_kind: MatchKind,
prefilter: bool,
anchored: bool,
ascii_case_insensitive: bool,
}
impl Default for Builder {
fn default() -> Builder {
Builder {
dense_depth: 2,
match_kind: MatchKind::default(),
prefilter: true,
anchored: false,
ascii_case_insensitive: false,
}
}
}
impl Builder {
pub fn new() -> Builder {
Builder::default()
}
pub fn build<I, P, S: StateID>(&self, patterns: I) -> Result<NFA<S>>
where
I: IntoIterator<Item = P>,
P: AsRef<[u8]>,
{
Compiler::new(self)?.compile(patterns)
}
pub fn match_kind(&mut self, kind: MatchKind) -> &mut Builder {
self.match_kind = kind;
self
}
pub fn dense_depth(&mut self, depth: usize) -> &mut Builder {
self.dense_depth = depth;
self
}
pub fn prefilter(&mut self, yes: bool) -> &mut Builder {
self.prefilter = yes;
self
}
pub fn anchored(&mut self, yes: bool) -> &mut Builder {
self.anchored = yes;
self
}
pub fn ascii_case_insensitive(&mut self, yes: bool) -> &mut Builder {
self.ascii_case_insensitive = yes;
self
}
}
/// Turns a `Builder` configuration and a sequence of patterns into an `NFA`.
/// It first builds a trie of the patterns, then fills in failure transitions.
#[derive(Debug)]
struct Compiler<'a, S: StateID> {
    /// Configuration supplied by the caller.
    builder: &'a Builder,
    /// Receives every added pattern when `builder.prefilter` is set.
    prefilter: prefilter::Builder,
    /// The automaton being built.
    nfa: NFA<S>,
    /// Byte classes, fed each pattern byte (and its opposite ASCII case
    /// when case-insensitive).
    byte_classes: ByteClassBuilder,
}

impl<'a, S: StateID> Compiler<'a, S> {
    /// Creates a compiler holding an empty NFA whose start state id is 2.
    ///
    /// # Errors
    ///
    /// Returns an error if `2` cannot be converted to the state id type `S`.
    fn new(builder: &'a Builder) -> Result<Compiler<'a, S>> {
        Ok(Compiler {
            builder: builder,
            prefilter: prefilter::Builder::new(builder.match_kind)
                .ascii_case_insensitive(builder.ascii_case_insensitive),
            nfa: NFA {
                match_kind: builder.match_kind,
                start_id: usize_to_state_id(2)?,
                max_pattern_len: 0,
                pattern_count: 0,
                heap_bytes: 0,
                prefilter: None,
                anchored: builder.anchored,
                byte_classes: ByteClasses::singletons(),
                states: vec![],
            },
            byte_classes: ByteClassBuilder::new(),
        })
    }

    /// Consumes the compiler and builds the finished NFA from `patterns`.
    ///
    /// # Errors
    ///
    /// Returns an error if a new state's index cannot be converted to `S`
    /// (see `usize_to_state_id`).
    fn compile<I, P>(mut self, patterns: I) -> Result<NFA<S>>
    where
        I: IntoIterator<Item = P>,
        P: AsRef<[u8]>,
    {
        // The first three states take ids 0, 1 and 2. Id 2 is the start
        // state (see `new`). Ids 0 and 1 are presumably the reserved fail and
        // dead ids from `state_id`. The dead state must exist, because
        // `add_dead_state_loop` writes to `dead_id()`.
        self.add_state(0)?;
        self.add_state(0)?;
        self.add_state(0)?;
        self.build_trie(patterns)?;
        self.add_start_state_loop();
        self.add_dead_state_loop();
        if !self.builder.anchored {
            if self.match_kind().is_leftmost() {
                self.fill_failure_transitions_leftmost();
            } else {
                self.fill_failure_transitions_standard();
            }
        }
        self.close_start_state_loop();
        self.nfa.byte_classes = self.byte_classes.build();
        if !self.builder.anchored {
            self.nfa.prefilter = self.prefilter.build();
        }
        self.calculate_size();
        Ok(self.nfa)
    }

    /// Adds every pattern to the trie and records a match at the state
    /// where each pattern ends.
    ///
    /// Under leftmost-first semantics, a pattern whose proper prefix is an
    /// earlier-added pattern is skipped entirely. It still counts toward
    /// `pattern_count` and `max_pattern_len`.
    fn build_trie<I, P>(&mut self, patterns: I) -> Result<()>
    where
        I: IntoIterator<Item = P>,
        P: AsRef<[u8]>,
    {
        'PATTERNS: for (pati, pat) in patterns.into_iter().enumerate() {
            let pat = pat.as_ref();
            self.nfa.max_pattern_len =
                cmp::max(self.nfa.max_pattern_len, pat.len());
            self.nfa.pattern_count += 1;

            let mut prev = self.nfa.start_id;
            let mut saw_match = false;
            for (depth, &b) in pat.iter().enumerate() {
                // Leftmost-first: such a pattern can never win against the
                // earlier pattern that is its prefix, so drop it.
                saw_match = saw_match || self.nfa.state(prev).is_match();
                if self.builder.match_kind.is_leftmost_first() && saw_match {
                    continue 'PATTERNS;
                }

                self.byte_classes.set_range(b, b);
                if self.builder.ascii_case_insensitive {
                    let b = opposite_ascii_case(b);
                    self.byte_classes.set_range(b, b);
                }

                let next = self.nfa.state(prev).next_state(b);
                if next != fail_id() {
                    prev = next;
                } else {
                    let next = self.add_state(depth + 1)?;
                    self.nfa.state_mut(prev).set_next_state(b, next);
                    // Case-insensitive: both cases lead to the same child,
                    // so that child has two incoming transitions from `prev`.
                    if self.builder.ascii_case_insensitive {
                        let b = opposite_ascii_case(b);
                        self.nfa.state_mut(prev).set_next_state(b, next);
                    }
                    prev = next;
                }
            }
            self.nfa.state_mut(prev).add_match(pati, pat.len());
            if self.builder.prefilter {
                self.prefilter.add(pat);
            }
        }
        Ok(())
    }

    /// Computes standard (non-leftmost) failure transitions with a
    /// breadth-first walk of the trie. Each state also inherits the matches
    /// of its failure state.
    fn fill_failure_transitions_standard(&mut self) {
        // Seed the search with the start state's children. The start
        // state's self-loops are skipped; following them would never end.
        let mut queue = VecDeque::new();
        let mut seen = self.queued_set();
        for b in AllBytesIter::new() {
            let next = self.nfa.start().next_state(b);
            if next != self.nfa.start_id && !seen.contains(next) {
                queue.push_back(next);
                seen.insert(next);
            }
        }
        while let Some(id) = queue.pop_front() {
            let mut it = self.nfa.iter_transitions_mut(id);
            while let Some((b, next)) = it.next() {
                if seen.contains(next) {
                    // A state is seen twice only when ASCII case
                    // insensitivity gives it two transitions from the same
                    // parent. It was fully processed on the first visit.
                    // Processing it again would copy its failure state's
                    // matches a second time, producing duplicate matches.
                    continue;
                }
                queue.push_back(next);
                seen.insert(next);

                // Walk the parent's failure chain until some state has a
                // transition on `b`. That target is `next`'s failure state.
                let mut fail = it.nfa().state(id).fail;
                while it.nfa().state(fail).next_state(b) == fail_id() {
                    fail = it.nfa().state(fail).fail;
                }
                fail = it.nfa().state(fail).next_state(b);
                it.nfa().state_mut(next).fail = fail;
                // The failure state spells a suffix of `next`, so all of
                // its matches are also matches here.
                it.nfa().copy_matches(fail, next);
            }
            // If the empty pattern was added, the start state matches, and
            // every state must report that match too.
            it.nfa().copy_empty_matches(id);
        }
    }

    /// Computes failure transitions for leftmost semantics. This works like
    /// the standard construction, except that a state which has seen a
    /// match keeps its failure transition only if the failure state still
    /// contains that match. Otherwise it fails to the dead state, so the
    /// search stops there.
    fn fill_failure_transitions_leftmost(&mut self) {
        /// A queued state plus the depth at which the earliest match seen on
        /// its path begins, if any.
        #[derive(Clone, Copy, Debug)]
        struct QueuedState<S> {
            id: S,
            match_at_depth: Option<usize>,
        }

        impl<S: StateID> QueuedState<S> {
            /// The start state; it has a match at depth 0 iff it matches.
            fn start(nfa: &NFA<S>) -> QueuedState<S> {
                let match_at_depth =
                    if nfa.start().is_match() { Some(0) } else { None };
                QueuedState { id: nfa.start_id, match_at_depth }
            }

            /// The queued state for child `id`, inheriting or establishing
            /// the earliest match depth.
            fn next_queued_state(
                &self,
                nfa: &NFA<S>,
                id: S,
            ) -> QueuedState<S> {
                let match_at_depth = self.next_match_at_depth(nfa, id);
                QueuedState { id, match_at_depth }
            }

            /// Keeps an already-seen match depth. Otherwise, if `next`
            /// matches, returns the depth where its first match begins.
            fn next_match_at_depth(
                &self,
                nfa: &NFA<S>,
                next: S,
            ) -> Option<usize> {
                match self.match_at_depth {
                    Some(x) => return Some(x),
                    None if nfa.state(next).is_match() => {}
                    None => return None,
                }
                let depth = nfa.state(next).depth
                    - nfa.state(next).get_longest_match_len().unwrap()
                    + 1;
                Some(depth)
            }
        }

        let mut queue: VecDeque<QueuedState<S>> = VecDeque::new();
        let mut seen = self.queued_set();
        let start = QueuedState::start(&self.nfa);
        for b in AllBytesIter::new() {
            let next_id = self.nfa.start().next_state(b);
            if next_id != start.id {
                let next = start.next_queued_state(&self.nfa, next_id);
                if !seen.contains(next.id) {
                    queue.push_back(next);
                    seen.insert(next.id);
                }
                // A match state right after the start state could only
                // fail back to the start; leftmost search must never do
                // that after a match.
                if self.nfa.state(next_id).is_match() {
                    self.nfa.state_mut(next_id).fail = dead_id();
                }
            }
        }
        while let Some(item) = queue.pop_front() {
            let mut any_trans = false;
            let mut it = self.nfa.iter_transitions_mut(item.id);
            while let Some((b, next_id)) = it.next() {
                any_trans = true;

                let next = item.next_queued_state(it.nfa(), next_id);
                if seen.contains(next.id) {
                    // Only reachable with ASCII case insensitivity (two
                    // transitions from one parent to the same child). The
                    // child was fully processed on its first visit, and
                    // re-processing would duplicate its copied matches.
                    continue;
                }
                queue.push_back(next);
                seen.insert(next.id);

                let mut fail = it.nfa().state(item.id).fail;
                while it.nfa().state(fail).next_state(b) == fail_id() {
                    fail = it.nfa().state(fail).fail;
                }
                fail = it.nfa().state(fail).next_state(b);

                if let Some(match_depth) = next.match_at_depth {
                    // Keep the failure transition only if the suffix it
                    // represents is long enough to contain the match.
                    let fail_depth = it.nfa().state(fail).depth;
                    let next_depth = it.nfa().state(next.id).depth;
                    if next_depth - match_depth + 1 > fail_depth {
                        it.nfa().state_mut(next.id).fail = dead_id();
                        continue;
                    }
                    assert_ne!(
                        start.id,
                        it.nfa().state(next.id).fail,
                        "states that are match states or follow match \
                         states should never have a failure transition \
                         back to the start state in leftmost searching",
                    );
                }
                it.nfa().state_mut(next.id).fail = fail;
                it.nfa().copy_matches(fail, next.id);
            }
            // A match state with no outgoing transitions must stop the
            // search rather than restart it.
            if !any_trans && it.nfa().state(item.id).is_match() {
                it.nfa().state_mut(item.id).fail = dead_id();
            }
        }
    }

    /// Returns the set used to de-duplicate states during the failure walks.
    /// The set only tracks states when ASCII case insensitivity is on,
    /// because only then can one state be reached by two transitions from
    /// the same parent. Otherwise the trie guarantees a single visit, so
    /// tracking is skipped.
    fn queued_set(&self) -> QueuedSet<S> {
        if self.builder.ascii_case_insensitive {
            QueuedSet::active()
        } else {
            QueuedSet::inert()
        }
    }

    /// Points every byte without a transition out of the start state back
    /// at the start state.
    fn add_start_state_loop(&mut self) {
        let start_id = self.nfa.start_id;
        let start = self.nfa.start_mut();
        for b in AllBytesIter::new() {
            if start.next_state(b) == fail_id() {
                start.set_next_state(b, start_id);
            }
        }
    }

    /// Replaces the start state's self-loops with transitions to the dead
    /// state. This happens in anchored mode, and in leftmost mode when the
    /// start state itself matches (the empty pattern was added).
    fn close_start_state_loop(&mut self) {
        if self.builder.anchored
            || (self.match_kind().is_leftmost() && self.nfa.start().is_match())
        {
            let start_id = self.nfa.start_id;
            let start = self.nfa.start_mut();
            for b in AllBytesIter::new() {
                if start.next_state(b) == start_id {
                    start.set_next_state(b, dead_id());
                }
            }
        }
    }

    /// Makes every byte from the dead state lead back to the dead state.
    fn add_dead_state_loop(&mut self) {
        let dead = self.nfa.state_mut(dead_id());
        for b in AllBytesIter::new() {
            dead.set_next_state(b, dead_id());
        }
    }

    /// Caches the total heap usage of all states in `nfa.heap_bytes`.
    fn calculate_size(&mut self) {
        self.nfa.heap_bytes =
            self.nfa.states.iter().map(|state| state.heap_bytes()).sum();
    }

    /// Adds a state at `depth`. The state is dense when `depth` is below the
    /// builder's `dense_depth` and sparse otherwise.
    fn add_state(&mut self, depth: usize) -> Result<S> {
        if depth < self.builder.dense_depth {
            self.nfa.add_dense_state(depth)
        } else {
            self.nfa.add_sparse_state(depth)
        }
    }

    /// The configured match semantics.
    fn match_kind(&self) -> MatchKind {
        self.builder.match_kind
    }
}
/// Set of already-queued states. An inert set tracks nothing: `insert` is a
/// no-op and `contains` is always false.
#[derive(Debug)]
struct QueuedSet<S> {
    set: Option<BTreeSet<S>>,
}

impl<S: StateID> QueuedSet<S> {
    /// A set that ignores insertions and never reports membership.
    fn inert() -> QueuedSet<S> {
        QueuedSet { set: None }
    }

    /// A set that actually records states.
    fn active() -> QueuedSet<S> {
        QueuedSet { set: Some(BTreeSet::new()) }
    }

    /// Records `state_id` (no-op when inert).
    fn insert(&mut self, state_id: S) {
        if let Some(set) = self.set.as_mut() {
            set.insert(state_id);
        }
    }

    /// True if `state_id` was recorded (always false when inert).
    fn contains(&self, state_id: S) -> bool {
        self.set.as_ref().map_or(false, |set| set.contains(&state_id))
    }
}
/// Yields every byte value from 0 through 255, in ascending order.
/// A `u16` counter is used so that the value 256 can mark exhaustion.
#[derive(Debug)]
struct AllBytesIter(u16);

impl AllBytesIter {
    fn new() -> AllBytesIter {
        AllBytesIter(0)
    }
}

impl Iterator for AllBytesIter {
    type Item = u8;

    fn next(&mut self) -> Option<Self::Item> {
        match self.0 {
            n if n < 256 => {
                self.0 = n + 1;
                Some(n as u8)
            }
            _ => None,
        }
    }
}
impl<S: StateID> fmt::Debug for NFA<S> {
    /// Dumps every state: its non-fail transitions (omitting the start
    /// state's self-loops and all dead-state transitions), matching pattern
    /// ids, failure transition and depth.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let rule = "-".repeat(79);
        writeln!(f, "NFA(")?;
        writeln!(f, "match_kind: {:?}", self.match_kind)?;
        writeln!(f, "{}", rule)?;
        let start = self.start_id.to_usize();
        for (id, state) in self.states.iter().enumerate() {
            let is_start = id == start;
            let is_dead = id == dead_id();
            let mut trans = Vec::new();
            state.trans.iter(|byte, next| {
                // Self-loops on the start state and the dead state's loops
                // are noise in the dump.
                let start_self_loop = is_start && next == self.start_id;
                if !start_self_loop && !is_dead {
                    trans.push(format!(
                        "{} => {}",
                        escape(byte),
                        next.to_usize()
                    ));
                }
            });
            let pattern_ids: Vec<String> = state
                .matches
                .iter()
                .map(|&(pattern_id, _)| pattern_id.to_string())
                .collect();
            writeln!(f, "{:04}: {}", id, trans.join(", "))?;
            writeln!(f, " matches: {}", pattern_ids.join(", "))?;
            writeln!(f, " fail: {}", state.fail.to_usize())?;
            writeln!(f, " depth: {}", state.depth)?;
        }
        writeln!(f, "{}", rule)?;
        writeln!(f, ")")?;
        Ok(())
    }
}
fn sparse_iter<S: StateID, F: FnMut(u8, S)>(trans: &[(u8, S)], mut f: F) {
let mut byte = 0u16;
for &(b, id) in trans {
while byte < (b as u16) {
f(byte as u8, fail_id());
byte += 1;
}
f(b, id);
byte += 1;
}
for b in byte..256 {
f(b as u8, fail_id());
}
}
/// Returns mutable references to `xs[i]` and `xs[j]`, in that order.
///
/// # Panics
///
/// Panics if `i == j` or if either index is out of bounds.
fn get_two_mut<T>(xs: &mut [T], i: usize, j: usize) -> (&mut T, &mut T) {
    assert!(i != j, "{} must not be equal to {}", i, j);
    let (lo, hi) = if i < j { (i, j) } else { (j, i) };
    let (head, tail) = xs.split_at_mut(hi);
    let low = &mut head[lo];
    let high = &mut tail[0];
    if i < j {
        (low, high)
    } else {
        (high, low)
    }
}
/// Renders `b` using Rust's default ASCII escaping (e.g. `\n`, `\xff`).
fn escape(b: u8) -> String {
    use std::ascii;
    // `escape_default` yields only ASCII bytes, so mapping each byte to a
    // `char` builds the same string as decoding it as UTF-8.
    ascii::escape_default(b).map(char::from).collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a small leftmost-first NFA with all-sparse states and prints
    /// its debug dump. This is a manual inspection aid, not a correctness
    /// check.
    #[test]
    fn scratch() {
        let patterns = ["abcdefg", "bcde", "bcdef"];
        let nfa: NFA<usize> = Builder::new()
            .dense_depth(0)
            .match_kind(MatchKind::LeftmostFirst)
            .build(&patterns)
            .unwrap();
        println!("{:?}", nfa);
    }
}