#[cfg(feature = "dfa-build")]
use core::iter;
use core::{fmt, mem::size_of};
#[cfg(feature = "dfa-build")]
use alloc::{vec, vec::Vec};
#[cfg(feature = "dfa-build")]
use crate::dfa::dense::{self, BuildError};
use crate::{
dfa::{
automaton::{fmt_state_indicator, Automaton, StartError},
dense::Flags,
special::Special,
StartKind, DEAD,
},
util::{
alphabet::{ByteClasses, ByteSet},
escape::DebugByte,
int::{Pointer, Usize, U16, U32},
prefilter::Prefilter,
primitives::{PatternID, StateID},
search::Anchored,
start::{self, Start, StartByteMap},
wire::{self, DeserializeError, Endian, SerializeError},
},
};
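// The label and version number written into the serialized form of a sparse
// DFA. Both are checked at deserialization time, so bumping `VERSION`
// invalidates any previously serialized DFAs.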
const LABEL: &str = "rust-regex-automata-dfa-sparse";
const VERSION: u32 = 2;
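/// A sparse DFA with variable-sized states.
///
/// In exchange for (typically much) lower memory usage than a dense DFA, a
/// transition lookup scans the byte ranges of the current state instead of
/// indexing into a fixed-size table.
///
/// A sketch of typical construction and search, assuming the `dfa-build` and
/// `syntax` features and the crate-root `Input` and `HalfMatch` types used
/// elsewhere in this crate:
///
/// ```
/// use regex_automata::{dfa::{Automaton, sparse}, HalfMatch, Input};
///
/// let dfa = sparse::DFA::new("foo[0-9]+")?;
/// // A DFA search reports the end offset of the leftmost match.
/// let expected = Some(HalfMatch::must(0, 8));
/// assert_eq!(expected, dfa.try_search_fwd(&Input::new("foo12345"))?);
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```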
#[derive(Clone)]
pub struct DFA<T> {
tt: Transitions<T>,
st: StartTable<T>,
special: Special,
pre: Option<Prefilter>,
quitset: ByteSet,
flags: Flags,
}
#[cfg(feature = "dfa-build")]
impl DFA<Vec<u8>> {
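    /// Parse the given regular expression using a default configuration and
    /// return the corresponding sparse DFA.
    ///
    /// This is a convenience that builds a dense DFA and then converts it via
    /// `to_sparse`. To customize the configuration, use `dense::Builder`
    /// directly and convert the result.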
#[cfg(feature = "syntax")]
pub fn new(pattern: &str) -> Result<DFA<Vec<u8>>, BuildError> {
dense::Builder::new()
.build(pattern)
.and_then(|dense| dense.to_sparse())
}
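    /// Like `new`, but compiles several patterns into a single sparse DFA.
    /// Match states record which of the patterns matched.
    ///
    /// A sketch of multi-pattern usage, assuming the crate-root `Input` and
    /// `HalfMatch` types:
    ///
    /// ```
    /// use regex_automata::{dfa::{Automaton, sparse}, HalfMatch, Input};
    ///
    /// let dfa = sparse::DFA::new_many(&["[0-9]+", "[a-z]+"])?;
    /// // The leftmost match is "foo", reported for pattern 1.
    /// let expected = Some(HalfMatch::must(1, 3));
    /// assert_eq!(expected, dfa.try_search_fwd(&Input::new("foo12345bar"))?);
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```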
#[cfg(feature = "syntax")]
pub fn new_many<P: AsRef<str>>(
patterns: &[P],
) -> Result<DFA<Vec<u8>>, BuildError> {
dense::Builder::new()
.build_many(patterns)
.and_then(|dense| dense.to_sparse())
}
}
#[cfg(feature = "dfa-build")]
impl DFA<Vec<u8>> {
pub fn always_match() -> Result<DFA<Vec<u8>>, BuildError> {
dense::DFA::always_match()?.to_sparse()
}
pub fn never_match() -> Result<DFA<Vec<u8>>, BuildError> {
dense::DFA::never_match()?.to_sparse()
}
pub(crate) fn from_dense<T: AsRef<[u32]>>(
dfa: &dense::DFA<T>,
) -> Result<DFA<Vec<u8>>, BuildError> {
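        // Convert the dense representation into the sparse one. Each state is
        // encoded as a variable-length sequence of bytes:
        //
        //   ntrans:       a u16 transition count, with the high bit set when
        //                 this is a match state
        //   input ranges: ntrans (start, end) byte pairs; the final pair is a
        //                 placeholder for the implicit EOI transition
        //   next IDs:     ntrans state IDs, StateID::SIZE bytes each
        //   pattern IDs:  match states only: a u32 count followed by that
        //                 many u32 pattern IDs
        //   accel:        a one byte length (at most 3) followed by the
        //                 accelerator bytes, if any
        //
        // `remap` maps each dense state (by index) to the byte offset at
        // which its sparse encoding begins.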
let mut sparse = Vec::with_capacity(StateID::SIZE * dfa.state_len());
let mut remap: Vec<StateID> = vec![DEAD; dfa.state_len()];
for state in dfa.states() {
let pos = sparse.len();
remap[dfa.to_index(state.id())] = StateID::new(pos)
.map_err(|_| BuildError::too_many_states())?;
sparse.push(0);
sparse.push(0);
let mut transition_len = 0;
for (unit1, unit2, _) in state.sparse_transitions() {
match (unit1.as_u8(), unit2.as_u8()) {
(Some(b1), Some(b2)) => {
transition_len += 1;
sparse.push(b1);
sparse.push(b2);
}
(None, None) => {}
(Some(_), None) | (None, Some(_)) => {
unreachable!()
}
}
}
transition_len += 1;
sparse.push(0);
sparse.push(0);
assert_ne!(
transition_len, 0,
"transition length should be non-zero",
);
assert!(
transition_len <= 257,
"expected transition length {transition_len} to be <= 257",
);
let ntrans = if dfa.is_match_state(state.id()) {
transition_len | (1 << 15)
} else {
transition_len
};
wire::NE::write_u16(ntrans, &mut sparse[pos..]);
let zeros = usize::try_from(transition_len)
.unwrap()
.checked_mul(StateID::SIZE)
.unwrap();
sparse.extend(iter::repeat(0).take(zeros));
if dfa.is_match_state(state.id()) {
let plen = dfa.match_pattern_len(state.id());
let mut pos = sparse.len();
let zeros = size_of::<u32>()
.checked_mul(plen)
.unwrap()
.checked_add(size_of::<u32>())
.unwrap();
sparse.extend(iter::repeat(0).take(zeros));
wire::NE::write_u32(
plen.try_into().expect("pattern ID length fits in u32"),
&mut sparse[pos..],
);
pos += size_of::<u32>();
for &pid in dfa.pattern_id_slice(state.id()) {
pos += wire::write_pattern_id::<wire::NE>(
pid,
&mut sparse[pos..],
);
}
}
let accel = dfa.accelerator(state.id());
sparse.push(accel.len().try_into().unwrap());
sparse.extend_from_slice(accel);
}
let mut new = DFA {
tt: Transitions {
sparse,
classes: dfa.byte_classes().clone(),
state_len: dfa.state_len(),
pattern_len: dfa.pattern_len(),
},
st: StartTable::from_dense_dfa(dfa, &remap)?,
special: dfa.special().remap(|id| remap[dfa.to_index(id)]),
pre: dfa.get_prefilter().map(|p| p.clone()),
quitset: dfa.quitset().clone(),
flags: dfa.flags().clone(),
};
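        // The next state IDs above were written as placeholders, since the
        // sparse offset of a state isn't known until every state has been
        // encoded. Do a second pass to remap and fill in the transitions.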
for old_state in dfa.states() {
let new_id = remap[dfa.to_index(old_state.id())];
let mut new_state = new.tt.state_mut(new_id);
let sparse = old_state.sparse_transitions();
for (i, (_, _, next)) in sparse.enumerate() {
let next = remap[dfa.to_index(next)];
new_state.set_next_at(i, next);
}
}
new.tt.sparse.shrink_to_fit();
new.st.table.shrink_to_fit();
debug!(
"created sparse DFA, memory usage: {} (dense memory usage: {})",
new.memory_usage(),
dfa.memory_usage(),
);
Ok(new)
}
}
impl<T: AsRef<[u8]>> DFA<T> {
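    /// Cheaply return a borrowed version of this sparse DFA. The borrowed DFA
    /// refers to this DFA's underlying transition and start table bytes
    /// without copying them.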
pub fn as_ref<'a>(&'a self) -> DFA<&'a [u8]> {
DFA {
tt: self.tt.as_ref(),
st: self.st.as_ref(),
special: self.special,
pre: self.pre.clone(),
quitset: self.quitset,
flags: self.flags,
}
}
#[cfg(feature = "alloc")]
pub fn to_owned(&self) -> DFA<alloc::vec::Vec<u8>> {
DFA {
tt: self.tt.to_owned(),
st: self.st.to_owned(),
special: self.special,
pre: self.pre.clone(),
quitset: self.quitset,
flags: self.flags,
}
}
pub fn start_kind(&self) -> StartKind {
self.st.kind
}
pub fn starts_for_each_pattern(&self) -> bool {
self.st.pattern_len.is_some()
}
pub fn byte_classes(&self) -> &ByteClasses {
&self.tt.classes
}
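    /// Returns the memory usage, in bytes, of the transition table and start
    /// table of this DFA. It does not include the size of the DFA value
    /// itself or of any prefilter.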
pub fn memory_usage(&self) -> usize {
self.tt.memory_usage() + self.st.memory_usage()
}
}
impl<T: AsRef<[u8]>> DFA<T> {
#[cfg(feature = "dfa-build")]
pub fn to_bytes_little_endian(&self) -> Vec<u8> {
self.to_bytes::<wire::LE>()
}
#[cfg(feature = "dfa-build")]
pub fn to_bytes_big_endian(&self) -> Vec<u8> {
self.to_bytes::<wire::BE>()
}
#[cfg(feature = "dfa-build")]
pub fn to_bytes_native_endian(&self) -> Vec<u8> {
self.to_bytes::<wire::NE>()
}
#[cfg(feature = "dfa-build")]
fn to_bytes<E: Endian>(&self) -> Vec<u8> {
let mut buf = vec![0; self.write_to_len()];
self.write_to::<E>(&mut buf).unwrap();
buf
}
pub fn write_to_little_endian(
&self,
dst: &mut [u8],
) -> Result<usize, SerializeError> {
self.write_to::<wire::LE>(dst)
}
pub fn write_to_big_endian(
&self,
dst: &mut [u8],
) -> Result<usize, SerializeError> {
self.write_to::<wire::BE>(dst)
}
pub fn write_to_native_endian(
&self,
dst: &mut [u8],
) -> Result<usize, SerializeError> {
self.write_to::<wire::NE>(dst)
}
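    // The serialized form is, in order: the label, an endianness check, the
    // format version, an unused u32, the flags, the transition table, the
    // start table, the special state ranges and the quit byte set.
    // `from_bytes_unchecked` reads the components back in the same order.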
fn write_to<E: Endian>(
&self,
dst: &mut [u8],
) -> Result<usize, SerializeError> {
let mut nw = 0;
nw += wire::write_label(LABEL, &mut dst[nw..])?;
nw += wire::write_endianness_check::<E>(&mut dst[nw..])?;
nw += wire::write_version::<E>(VERSION, &mut dst[nw..])?;
nw += {
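            // Currently unused space, written as zero. The deserializer reads
            // and discards it (see `from_bytes_unchecked`).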
E::write_u32(0, &mut dst[nw..]);
size_of::<u32>()
};
nw += self.flags.write_to::<E>(&mut dst[nw..])?;
nw += self.tt.write_to::<E>(&mut dst[nw..])?;
nw += self.st.write_to::<E>(&mut dst[nw..])?;
nw += self.special.write_to::<E>(&mut dst[nw..])?;
nw += self.quitset.write_to::<E>(&mut dst[nw..])?;
Ok(nw)
}
pub fn write_to_len(&self) -> usize {
wire::write_label_len(LABEL)
+ wire::write_endianness_check_len()
+ wire::write_version_len()
            + size_of::<u32>() // unused space
            + self.flags.write_to_len()
+ self.tt.write_to_len()
+ self.st.write_to_len()
+ self.special.write_to_len()
+ self.quitset.write_to_len()
}
}
impl<'a> DFA<&'a [u8]> {
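    /// Deserialize a sparse DFA from the given slice, as written by one of
    /// the `to_bytes_*` or `write_to_*` methods, and return it along with the
    /// number of bytes read.
    ///
    /// Unlike `from_bytes_unchecked`, this validates the transition and start
    /// tables before returning.
    ///
    /// A sketch of a serialization round trip, assuming the `dfa-build` and
    /// `syntax` features and the crate-root `Input` and `HalfMatch` types:
    ///
    /// ```
    /// use regex_automata::{dfa::{Automaton, sparse::DFA}, HalfMatch, Input};
    ///
    /// let original = DFA::new("foo[0-9]+")?;
    /// let buf = original.to_bytes_native_endian();
    /// let dfa: DFA<&[u8]> = DFA::from_bytes(&buf)?.0;
    ///
    /// let expected = Some(HalfMatch::must(0, 8));
    /// assert_eq!(expected, dfa.try_search_fwd(&Input::new("foo12345"))?);
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```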
pub fn from_bytes(
slice: &'a [u8],
) -> Result<(DFA<&'a [u8]>, usize), DeserializeError> {
let (dfa, nread) = unsafe { DFA::from_bytes_unchecked(slice)? };
let seen = dfa.tt.validate(&dfa.special)?;
dfa.st.validate(&dfa.special, &seen)?;
Ok((dfa, nread))
}
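    /// Like `from_bytes`, but skips validation of the transition and start
    /// tables.
    ///
    /// # Safety
    ///
    /// The caller must ensure that `slice` contains a valid serialized sparse
    /// DFA, e.g., one produced by `to_bytes_*` or `write_to_*` on the same
    /// target. Passing arbitrary bytes may lead to panics, incorrect search
    /// results or undefined behavior.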
pub unsafe fn from_bytes_unchecked(
slice: &'a [u8],
) -> Result<(DFA<&'a [u8]>, usize), DeserializeError> {
let mut nr = 0;
nr += wire::read_label(&slice[nr..], LABEL)?;
nr += wire::read_endianness_check(&slice[nr..])?;
nr += wire::read_version(&slice[nr..], VERSION)?;
let _unused = wire::try_read_u32(&slice[nr..], "unused space")?;
nr += size_of::<u32>();
let (flags, nread) = Flags::from_bytes(&slice[nr..])?;
nr += nread;
let (tt, nread) = Transitions::from_bytes_unchecked(&slice[nr..])?;
nr += nread;
let (st, nread) = StartTable::from_bytes_unchecked(&slice[nr..])?;
nr += nread;
let (special, nread) = Special::from_bytes(&slice[nr..])?;
nr += nread;
if special.max.as_usize() >= tt.sparse().len() {
return Err(DeserializeError::generic(
"max should not be greater than or equal to sparse bytes",
));
}
let (quitset, nread) = ByteSet::from_bytes(&slice[nr..])?;
nr += nread;
let pre = None;
Ok((DFA { tt, st, special, pre, quitset, flags }, nr))
}
}
impl<T> DFA<T> {
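    /// Set (or clear) the prefilter used by search routines that support one.
    ///
    /// Note that a prefilter is not part of the serialized form: a DFA
    /// returned by `from_bytes` starts with no prefilter, so it must be set
    /// again after deserialization if one is desired.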
pub fn set_prefilter(&mut self, prefilter: Option<Prefilter>) {
self.pre = prefilter
}
}
impl<T: AsRef<[u8]>> fmt::Debug for DFA<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "sparse::DFA(")?;
for state in self.tt.states() {
fmt_state_indicator(f, self, state.id())?;
writeln!(f, "{:06?}: {:?}", state.id().as_usize(), state)?;
}
writeln!(f, "")?;
for (i, (start_id, anchored, sty)) in self.st.iter().enumerate() {
if i % self.st.stride == 0 {
match anchored {
Anchored::No => writeln!(f, "START-GROUP(unanchored)")?,
Anchored::Yes => writeln!(f, "START-GROUP(anchored)")?,
Anchored::Pattern(pid) => writeln!(
f,
"START_GROUP(pattern: {:?})",
pid.as_usize()
)?,
}
}
writeln!(f, " {:?} => {:06?}", sty, start_id.as_usize())?;
}
writeln!(f, "state length: {:?}", self.tt.state_len)?;
writeln!(f, "pattern length: {:?}", self.pattern_len())?;
writeln!(f, "flags: {:?}", self.flags)?;
writeln!(f, ")")?;
Ok(())
}
}
unsafe impl<T: AsRef<[u8]>> Automaton for DFA<T> {
#[inline]
fn is_special_state(&self, id: StateID) -> bool {
self.special.is_special_state(id)
}
#[inline]
fn is_dead_state(&self, id: StateID) -> bool {
self.special.is_dead_state(id)
}
#[inline]
fn is_quit_state(&self, id: StateID) -> bool {
self.special.is_quit_state(id)
}
#[inline]
fn is_match_state(&self, id: StateID) -> bool {
self.special.is_match_state(id)
}
#[inline]
fn is_start_state(&self, id: StateID) -> bool {
self.special.is_start_state(id)
}
#[inline]
fn is_accel_state(&self, id: StateID) -> bool {
self.special.is_accel_state(id)
}
#[cfg_attr(feature = "perf-inline", inline(always))]
fn next_state(&self, current: StateID, input: u8) -> StateID {
let input = self.tt.classes.get(input);
self.tt.state(current).next(input)
}
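    // Sparse transition lookups are bounds checked by slice indexing, so
    // there is no cheaper "unchecked" path; simply defer to the checked
    // lookup.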
#[inline]
unsafe fn next_state_unchecked(
&self,
current: StateID,
input: u8,
) -> StateID {
self.next_state(current, input)
}
#[inline]
fn next_eoi_state(&self, current: StateID) -> StateID {
self.tt.state(current).next_eoi()
}
#[inline]
fn pattern_len(&self) -> usize {
self.tt.pattern_len
}
#[inline]
fn match_len(&self, id: StateID) -> usize {
self.tt.state(id).pattern_len()
}
#[inline]
fn match_pattern(&self, id: StateID, match_index: usize) -> PatternID {
if self.tt.pattern_len == 1 {
return PatternID::ZERO;
}
self.tt.state(id).pattern_id(match_index)
}
#[inline]
fn has_empty(&self) -> bool {
self.flags.has_empty
}
#[inline]
fn is_utf8(&self) -> bool {
self.flags.is_utf8
}
#[inline]
fn is_always_start_anchored(&self) -> bool {
self.flags.is_always_start_anchored
}
#[inline]
fn start_state(
&self,
config: &start::Config,
) -> Result<StateID, StartError> {
let anchored = config.get_anchored();
let start = match config.get_look_behind() {
None => Start::Text,
Some(byte) => {
if !self.quitset.is_empty() && self.quitset.contains(byte) {
return Err(StartError::quit(byte));
}
self.st.start_map.get(byte)
}
};
self.st.start(anchored, start)
}
#[inline]
fn universal_start_state(&self, mode: Anchored) -> Option<StateID> {
match mode {
Anchored::No => self.st.universal_start_unanchored,
Anchored::Yes => self.st.universal_start_anchored,
Anchored::Pattern(_) => None,
}
}
#[inline]
fn accelerator(&self, id: StateID) -> &[u8] {
self.tt.state(id).accelerator()
}
#[inline]
fn get_prefilter(&self) -> Option<&Prefilter> {
self.pre.as_ref()
}
}
#[derive(Clone)]
struct Transitions<T> {
sparse: T,
classes: ByteClasses,
state_len: usize,
pattern_len: usize,
}
impl<'a> Transitions<&'a [u8]> {
unsafe fn from_bytes_unchecked(
mut slice: &'a [u8],
) -> Result<(Transitions<&'a [u8]>, usize), DeserializeError> {
let slice_start = slice.as_ptr().as_usize();
let (state_len, nr) =
wire::try_read_u32_as_usize(&slice, "state length")?;
slice = &slice[nr..];
let (pattern_len, nr) =
wire::try_read_u32_as_usize(&slice, "pattern length")?;
slice = &slice[nr..];
let (classes, nr) = ByteClasses::from_bytes(&slice)?;
slice = &slice[nr..];
let (len, nr) =
wire::try_read_u32_as_usize(&slice, "sparse transitions length")?;
slice = &slice[nr..];
wire::check_slice_len(slice, len, "sparse states byte length")?;
let sparse = &slice[..len];
slice = &slice[len..];
let trans = Transitions { sparse, classes, state_len, pattern_len };
Ok((trans, slice.as_ptr().as_usize() - slice_start))
}
}
impl<T: AsRef<[u8]>> Transitions<T> {
fn write_to<E: Endian>(
&self,
mut dst: &mut [u8],
) -> Result<usize, SerializeError> {
let nwrite = self.write_to_len();
if dst.len() < nwrite {
return Err(SerializeError::buffer_too_small(
"sparse transition table",
));
}
dst = &mut dst[..nwrite];
E::write_u32(u32::try_from(self.state_len).unwrap(), dst);
dst = &mut dst[size_of::<u32>()..];
E::write_u32(u32::try_from(self.pattern_len).unwrap(), dst);
dst = &mut dst[size_of::<u32>()..];
let n = self.classes.write_to(dst)?;
dst = &mut dst[n..];
E::write_u32(u32::try_from(self.sparse().len()).unwrap(), dst);
dst = &mut dst[size_of::<u32>()..];
let mut id = DEAD;
while id.as_usize() < self.sparse().len() {
let state = self.state(id);
let n = state.write_to::<E>(&mut dst)?;
dst = &mut dst[n..];
id = StateID::new(id.as_usize() + state.write_to_len()).unwrap();
}
Ok(nwrite)
}
fn write_to_len(&self) -> usize {
        size_of::<u32>() // state length
            + size_of::<u32>() // pattern length
            + self.classes.write_to_len()
            + size_of::<u32>() // sparse transitions length
            + self.sparse().len()
}
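    // Validates that every encoded state can be decoded, that every
    // transition points at a state that actually exists and that the number
    // of decoded states matches `state_len`. The set of decoded state IDs is
    // returned so the start table can be validated against it.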
fn validate(&self, sp: &Special) -> Result<Seen, DeserializeError> {
let mut verified = Seen::new();
let mut len = 0;
let mut id = DEAD;
while id.as_usize() < self.sparse().len() {
if sp.is_special_state(id) {
let is_actually_special = sp.is_dead_state(id)
|| sp.is_quit_state(id)
|| sp.is_match_state(id)
|| sp.is_start_state(id)
|| sp.is_accel_state(id);
if !is_actually_special {
return Err(DeserializeError::generic(
"found sparse state tagged as special but \
wasn't actually special",
));
}
}
let state = self.try_state(sp, id)?;
verified.insert(id);
id = StateID::new(wire::add(
id.as_usize(),
state.write_to_len(),
"next state ID offset",
)?)
.map_err(|err| {
DeserializeError::state_id_error(err, "next state ID offset")
})?;
len += 1;
}
for state in self.states() {
for i in 0..state.ntrans {
let to = state.next_at(i);
#[cfg(not(feature = "alloc"))]
{
let _ = self.try_state(sp, to)?;
}
#[cfg(feature = "alloc")]
{
if !verified.contains(&to) {
return Err(DeserializeError::generic(
"found transition that points to a \
non-existent state",
));
}
}
}
}
if len != self.state_len {
return Err(DeserializeError::generic(
"mismatching sparse state length",
));
}
Ok(verified)
}
fn as_ref(&self) -> Transitions<&'_ [u8]> {
Transitions {
sparse: self.sparse(),
classes: self.classes.clone(),
state_len: self.state_len,
pattern_len: self.pattern_len,
}
}
#[cfg(feature = "alloc")]
fn to_owned(&self) -> Transitions<alloc::vec::Vec<u8>> {
Transitions {
sparse: self.sparse().to_vec(),
classes: self.classes.clone(),
state_len: self.state_len,
pattern_len: self.pattern_len,
}
}
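    // Returns a view of the state whose encoding begins at the byte offset
    // given by `id`. This assumes `id` is valid; `try_state` is the fallible,
    // validating variant used during deserialization.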
#[cfg_attr(feature = "perf-inline", inline(always))]
fn state(&self, id: StateID) -> State<'_> {
let mut state = &self.sparse()[id.as_usize()..];
let mut ntrans = wire::read_u16(&state).as_usize();
let is_match = (1 << 15) & ntrans != 0;
ntrans &= !(1 << 15);
state = &state[2..];
let (input_ranges, state) = state.split_at(ntrans * 2);
let (next, state) = state.split_at(ntrans * StateID::SIZE);
let (pattern_ids, state) = if is_match {
let npats = wire::read_u32(&state).as_usize();
state[4..].split_at(npats * 4)
} else {
(&[][..], state)
};
let accel_len = usize::from(state[0]);
let accel = &state[1..accel_len + 1];
State { id, is_match, ntrans, input_ranges, next, pattern_ids, accel }
}
fn try_state(
&self,
sp: &Special,
id: StateID,
) -> Result<State<'_>, DeserializeError> {
if id.as_usize() > self.sparse().len() {
return Err(DeserializeError::generic(
"invalid caller provided sparse state ID",
));
}
let mut state = &self.sparse()[id.as_usize()..];
let (mut ntrans, _) =
wire::try_read_u16_as_usize(state, "state transition length")?;
let is_match = ((1 << 15) & ntrans) != 0;
ntrans &= !(1 << 15);
state = &state[2..];
if ntrans > 257 || ntrans == 0 {
return Err(DeserializeError::generic(
"invalid transition length",
));
}
if is_match && !sp.is_match_state(id) {
return Err(DeserializeError::generic(
"state marked as match but not in match ID range",
));
} else if !is_match && sp.is_match_state(id) {
return Err(DeserializeError::generic(
"state in match ID range but not marked as match state",
));
}
let input_ranges_len = ntrans.checked_mul(2).unwrap();
wire::check_slice_len(state, input_ranges_len, "sparse byte pairs")?;
let (input_ranges, state) = state.split_at(input_ranges_len);
for pair in input_ranges.chunks(2) {
let (start, end) = (pair[0], pair[1]);
if start > end {
return Err(DeserializeError::generic("invalid input range"));
}
}
let next_len = ntrans
.checked_mul(self.id_len())
.expect("state size * #trans should always fit in a usize");
wire::check_slice_len(state, next_len, "sparse trans state IDs")?;
let (next, state) = state.split_at(next_len);
for idbytes in next.chunks(self.id_len()) {
let (id, _) =
wire::read_state_id(idbytes, "sparse state ID in try_state")?;
wire::check_slice_len(
self.sparse(),
id.as_usize(),
"invalid sparse state ID",
)?;
}
let (pattern_ids, state) = if is_match {
let (npats, nr) =
wire::try_read_u32_as_usize(state, "pattern ID length")?;
let state = &state[nr..];
if npats == 0 {
return Err(DeserializeError::generic(
"state marked as a match, but pattern length is zero",
));
}
let pattern_ids_len =
wire::mul(npats, 4, "sparse pattern ID byte length")?;
wire::check_slice_len(
state,
pattern_ids_len,
"sparse pattern IDs",
)?;
let (pattern_ids, state) = state.split_at(pattern_ids_len);
for patbytes in pattern_ids.chunks(PatternID::SIZE) {
wire::read_pattern_id(
patbytes,
"sparse pattern ID in try_state",
)?;
}
(pattern_ids, state)
} else {
(&[][..], state)
};
if is_match && pattern_ids.is_empty() {
return Err(DeserializeError::generic(
"state marked as a match, but has no pattern IDs",
));
}
if sp.is_match_state(id) && pattern_ids.is_empty() {
return Err(DeserializeError::generic(
"state marked special as a match, but has no pattern IDs",
));
}
if sp.is_match_state(id) != is_match {
return Err(DeserializeError::generic(
"whether state is a match or not is inconsistent",
));
}
if state.is_empty() {
return Err(DeserializeError::generic("no accelerator length"));
}
let (accel_len, state) = (usize::from(state[0]), &state[1..]);
if accel_len > 3 {
return Err(DeserializeError::generic(
"sparse invalid accelerator length",
));
} else if accel_len == 0 && sp.is_accel_state(id) {
return Err(DeserializeError::generic(
"got no accelerators in state, but in accelerator ID range",
));
} else if accel_len > 0 && !sp.is_accel_state(id) {
return Err(DeserializeError::generic(
"state in accelerator ID range, but has no accelerators",
));
}
wire::check_slice_len(
state,
accel_len,
"sparse corrupt accelerator length",
)?;
let (accel, _) = (&state[..accel_len], &state[accel_len..]);
let state = State {
id,
is_match,
ntrans,
input_ranges,
next,
pattern_ids,
accel,
};
if sp.is_quit_state(state.next_at(state.ntrans - 1)) {
return Err(DeserializeError::generic(
"state with EOI transition to quit state is illegal",
));
}
Ok(state)
}
fn states(&self) -> StateIter<'_, T> {
StateIter { trans: self, id: DEAD.as_usize() }
}
fn sparse(&self) -> &[u8] {
self.sparse.as_ref()
}
fn id_len(&self) -> usize {
StateID::SIZE
}
fn memory_usage(&self) -> usize {
self.sparse().len()
}
}
#[cfg(feature = "dfa-build")]
impl<T: AsMut<[u8]>> Transitions<T> {
fn state_mut(&mut self, id: StateID) -> StateMut<'_> {
let mut state = &mut self.sparse_mut()[id.as_usize()..];
let mut ntrans = wire::read_u16(&state).as_usize();
let is_match = (1 << 15) & ntrans != 0;
ntrans &= !(1 << 15);
state = &mut state[2..];
let (input_ranges, state) = state.split_at_mut(ntrans * 2);
let (next, state) = state.split_at_mut(ntrans * StateID::SIZE);
let (pattern_ids, state) = if is_match {
let npats = wire::read_u32(&state).as_usize();
state[4..].split_at_mut(npats * 4)
} else {
(&mut [][..], state)
};
let accel_len = usize::from(state[0]);
let accel = &mut state[1..accel_len + 1];
StateMut {
id,
is_match,
ntrans,
input_ranges,
next,
pattern_ids,
accel,
}
}
fn sparse_mut(&mut self) -> &mut [u8] {
self.sparse.as_mut()
}
}
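/// A table of start state IDs, stored as native endian encoded bytes.
///
/// The table is laid out in groups of `stride` entries, where `stride` is the
/// number of starting configurations (`Start::len()`): first a group for
/// unanchored starts, then a group for anchored starts and then, when start
/// states for each pattern are enabled, one group per pattern.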
#[derive(Clone)]
struct StartTable<T> {
table: T,
kind: StartKind,
start_map: StartByteMap,
stride: usize,
pattern_len: Option<usize>,
universal_start_unanchored: Option<StateID>,
universal_start_anchored: Option<StateID>,
}
#[cfg(feature = "dfa-build")]
impl StartTable<Vec<u8>> {
fn new<T: AsRef<[u32]>>(
dfa: &dense::DFA<T>,
pattern_len: Option<usize>,
) -> StartTable<Vec<u8>> {
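        // One group of `stride` start states for unanchored searches, one for
        // anchored searches and, when enabled, one group per pattern. Each
        // entry occupies StateID::SIZE bytes.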
let stride = Start::len();
let len = stride
.checked_mul(pattern_len.unwrap_or(0))
.unwrap()
.checked_add(stride.checked_mul(2).unwrap())
.unwrap()
.checked_mul(StateID::SIZE)
.unwrap();
StartTable {
table: vec![0; len],
kind: dfa.start_kind(),
start_map: dfa.start_map().clone(),
stride,
pattern_len,
universal_start_unanchored: dfa
.universal_start_state(Anchored::No),
universal_start_anchored: dfa.universal_start_state(Anchored::Yes),
}
}
fn from_dense_dfa<T: AsRef<[u32]>>(
dfa: &dense::DFA<T>,
remap: &[StateID],
) -> Result<StartTable<Vec<u8>>, BuildError> {
let start_pattern_len = if dfa.starts_for_each_pattern() {
Some(dfa.pattern_len())
} else {
None
};
let mut sl = StartTable::new(dfa, start_pattern_len);
for (old_start_id, anchored, sty) in dfa.starts() {
let new_start_id = remap[dfa.to_index(old_start_id)];
sl.set_start(anchored, sty, new_start_id);
}
if let Some(ref mut id) = sl.universal_start_anchored {
*id = remap[dfa.to_index(*id)];
}
if let Some(ref mut id) = sl.universal_start_unanchored {
*id = remap[dfa.to_index(*id)];
}
Ok(sl)
}
}
impl<'a> StartTable<&'a [u8]> {
unsafe fn from_bytes_unchecked(
mut slice: &'a [u8],
) -> Result<(StartTable<&'a [u8]>, usize), DeserializeError> {
let slice_start = slice.as_ptr().as_usize();
let (kind, nr) = StartKind::from_bytes(slice)?;
slice = &slice[nr..];
let (start_map, nr) = StartByteMap::from_bytes(slice)?;
slice = &slice[nr..];
let (stride, nr) =
wire::try_read_u32_as_usize(slice, "sparse start table stride")?;
slice = &slice[nr..];
if stride != Start::len() {
return Err(DeserializeError::generic(
"invalid sparse starting table stride",
));
}
let (maybe_pattern_len, nr) =
wire::try_read_u32_as_usize(slice, "sparse start table patterns")?;
slice = &slice[nr..];
let pattern_len = if maybe_pattern_len.as_u32() == u32::MAX {
None
} else {
Some(maybe_pattern_len)
};
if pattern_len.map_or(false, |len| len > PatternID::LIMIT) {
return Err(DeserializeError::generic(
"sparse invalid number of patterns",
));
}
let (universal_unanchored, nr) =
wire::try_read_u32(slice, "universal unanchored start")?;
slice = &slice[nr..];
let universal_start_unanchored = if universal_unanchored == u32::MAX {
None
} else {
Some(StateID::try_from(universal_unanchored).map_err(|e| {
DeserializeError::state_id_error(
e,
"universal unanchored start",
)
})?)
};
let (universal_anchored, nr) =
wire::try_read_u32(slice, "universal anchored start")?;
slice = &slice[nr..];
let universal_start_anchored = if universal_anchored == u32::MAX {
None
} else {
Some(StateID::try_from(universal_anchored).map_err(|e| {
DeserializeError::state_id_error(e, "universal anchored start")
})?)
};
let pattern_table_size = wire::mul(
stride,
pattern_len.unwrap_or(0),
"sparse invalid pattern length",
)?;
let start_state_len = wire::add(
wire::mul(2, stride, "start state stride too big")?,
pattern_table_size,
"sparse invalid 'any' pattern starts size",
)?;
let table_bytes_len = wire::mul(
start_state_len,
StateID::SIZE,
"sparse pattern table bytes length",
)?;
wire::check_slice_len(
slice,
table_bytes_len,
"sparse start ID table",
)?;
let table = &slice[..table_bytes_len];
slice = &slice[table_bytes_len..];
let sl = StartTable {
table,
kind,
start_map,
stride,
pattern_len,
universal_start_unanchored,
universal_start_anchored,
};
Ok((sl, slice.as_ptr().as_usize() - slice_start))
}
}
impl<T: AsRef<[u8]>> StartTable<T> {
fn write_to<E: Endian>(
&self,
mut dst: &mut [u8],
) -> Result<usize, SerializeError> {
let nwrite = self.write_to_len();
if dst.len() < nwrite {
return Err(SerializeError::buffer_too_small(
"sparse starting table ids",
));
}
dst = &mut dst[..nwrite];
let nw = self.kind.write_to::<E>(dst)?;
dst = &mut dst[nw..];
let nw = self.start_map.write_to(dst)?;
dst = &mut dst[nw..];
E::write_u32(u32::try_from(self.stride).unwrap(), dst);
dst = &mut dst[size_of::<u32>()..];
E::write_u32(
u32::try_from(self.pattern_len.unwrap_or(0xFFFF_FFFF)).unwrap(),
dst,
);
dst = &mut dst[size_of::<u32>()..];
E::write_u32(
self.universal_start_unanchored
.map_or(u32::MAX, |sid| sid.as_u32()),
dst,
);
dst = &mut dst[size_of::<u32>()..];
E::write_u32(
self.universal_start_anchored.map_or(u32::MAX, |sid| sid.as_u32()),
dst,
);
dst = &mut dst[size_of::<u32>()..];
for (sid, _, _) in self.iter() {
E::write_u32(sid.as_u32(), dst);
dst = &mut dst[StateID::SIZE..];
}
Ok(nwrite)
}
fn write_to_len(&self) -> usize {
self.kind.write_to_len()
+ self.start_map.write_to_len()
            + size_of::<u32>() // stride
            + size_of::<u32>() // pattern length
            + size_of::<u32>() // universal unanchored start state
            + size_of::<u32>() // universal anchored start state
            + self.table().len()
}
fn validate(
&self,
sp: &Special,
seen: &Seen,
) -> Result<(), DeserializeError> {
for (id, _, _) in self.iter() {
if !seen.contains(&id) {
return Err(DeserializeError::generic(
"found invalid start state ID",
));
}
if sp.is_match_state(id) {
return Err(DeserializeError::generic(
"start states cannot be match states",
));
}
}
Ok(())
}
fn as_ref(&self) -> StartTable<&'_ [u8]> {
StartTable {
table: self.table(),
kind: self.kind,
start_map: self.start_map.clone(),
stride: self.stride,
pattern_len: self.pattern_len,
universal_start_unanchored: self.universal_start_unanchored,
universal_start_anchored: self.universal_start_anchored,
}
}
#[cfg(feature = "alloc")]
fn to_owned(&self) -> StartTable<alloc::vec::Vec<u8>> {
StartTable {
table: self.table().to_vec(),
kind: self.kind,
start_map: self.start_map.clone(),
stride: self.stride,
pattern_len: self.pattern_len,
universal_start_unanchored: self.universal_start_unanchored,
universal_start_anchored: self.universal_start_anchored,
}
}
fn start(
&self,
anchored: Anchored,
start: Start,
) -> Result<StateID, StartError> {
let start_index = start.as_usize();
let index = match anchored {
Anchored::No => {
if !self.kind.has_unanchored() {
return Err(StartError::unsupported_anchored(anchored));
}
start_index
}
Anchored::Yes => {
if !self.kind.has_anchored() {
return Err(StartError::unsupported_anchored(anchored));
}
self.stride + start_index
}
Anchored::Pattern(pid) => {
let len = match self.pattern_len {
None => {
return Err(StartError::unsupported_anchored(anchored))
}
Some(len) => len,
};
if pid.as_usize() >= len {
return Ok(DEAD);
}
(2 * self.stride)
+ (self.stride * pid.as_usize())
+ start_index
}
};
let start = index * StateID::SIZE;
Ok(wire::read_state_id_unchecked(&self.table()[start..]).0)
}
fn iter(&self) -> StartStateIter<'_, T> {
StartStateIter { st: self, i: 0 }
}
fn len(&self) -> usize {
self.table().len() / StateID::SIZE
}
fn table(&self) -> &[u8] {
self.table.as_ref()
}
fn memory_usage(&self) -> usize {
self.table().len()
}
}
#[cfg(feature = "dfa-build")]
impl<T: AsMut<[u8]>> StartTable<T> {
fn set_start(&mut self, anchored: Anchored, start: Start, id: StateID) {
let start_index = start.as_usize();
let index = match anchored {
Anchored::No => start_index,
Anchored::Yes => self.stride + start_index,
Anchored::Pattern(pid) => {
let pid = pid.as_usize();
let len = self
.pattern_len
.expect("start states for each pattern enabled");
assert!(pid < len, "invalid pattern ID {pid:?}");
self.stride
.checked_mul(pid)
.unwrap()
.checked_add(self.stride.checked_mul(2).unwrap())
.unwrap()
.checked_add(start_index)
.unwrap()
}
};
let start = index * StateID::SIZE;
let end = start + StateID::SIZE;
wire::write_state_id::<wire::NE>(
id,
&mut self.table.as_mut()[start..end],
);
}
}
struct StartStateIter<'a, T> {
st: &'a StartTable<T>,
i: usize,
}
impl<'a, T: AsRef<[u8]>> Iterator for StartStateIter<'a, T> {
type Item = (StateID, Anchored, Start);
fn next(&mut self) -> Option<(StateID, Anchored, Start)> {
let i = self.i;
if i >= self.st.len() {
return None;
}
self.i += 1;
let start_type = Start::from_usize(i % self.st.stride).unwrap();
let anchored = if i < self.st.stride {
Anchored::No
} else if i < (2 * self.st.stride) {
Anchored::Yes
} else {
let pid = (i - (2 * self.st.stride)) / self.st.stride;
Anchored::Pattern(PatternID::new(pid).unwrap())
};
let start = i * StateID::SIZE;
let end = start + StateID::SIZE;
let bytes = self.st.table()[start..end].try_into().unwrap();
let id = StateID::from_ne_bytes_unchecked(bytes);
Some((id, anchored, start_type))
}
}
impl<'a, T> fmt::Debug for StartStateIter<'a, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("StartStateIter").field("i", &self.i).finish()
}
}
struct StateIter<'a, T> {
trans: &'a Transitions<T>,
id: usize,
}
impl<'a, T: AsRef<[u8]>> Iterator for StateIter<'a, T> {
type Item = State<'a>;
fn next(&mut self) -> Option<State<'a>> {
if self.id >= self.trans.sparse().len() {
return None;
}
let state = self.trans.state(StateID::new_unchecked(self.id));
self.id = self.id + state.write_to_len();
Some(state)
}
}
impl<'a, T> fmt::Debug for StateIter<'a, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("StateIter").field("id", &self.id).finish()
}
}
#[derive(Clone)]
struct State<'a> {
id: StateID,
is_match: bool,
ntrans: usize,
input_ranges: &'a [u8],
next: &'a [u8],
pattern_ids: &'a [u8],
accel: &'a [u8],
}
impl<'a> State<'a> {
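    // Follows the transition for the given input byte (already mapped to its
    // equivalence class by the caller) by scanning this state's byte ranges.
    // The final transition is reserved for EOI and is skipped here.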
#[cfg_attr(feature = "perf-inline", inline(always))]
fn next(&self, input: u8) -> StateID {
for i in 0..(self.ntrans - 1) {
let (start, end) = self.range(i);
if start <= input && input <= end {
return self.next_at(i);
}
}
DEAD
}
fn next_eoi(&self) -> StateID {
self.next_at(self.ntrans - 1)
}
fn id(&self) -> StateID {
self.id
}
fn range(&self, i: usize) -> (u8, u8) {
(self.input_ranges[i * 2], self.input_ranges[i * 2 + 1])
}
fn next_at(&self, i: usize) -> StateID {
let start = i * StateID::SIZE;
let end = start + StateID::SIZE;
let bytes = self.next[start..end].try_into().unwrap();
StateID::from_ne_bytes_unchecked(bytes)
}
fn pattern_id(&self, match_index: usize) -> PatternID {
let start = match_index * PatternID::SIZE;
wire::read_pattern_id_unchecked(&self.pattern_ids[start..]).0
}
fn pattern_len(&self) -> usize {
assert_eq!(0, self.pattern_ids.len() % 4);
self.pattern_ids.len() / 4
}
fn accelerator(&self) -> &'a [u8] {
self.accel
}
fn write_to<E: Endian>(
&self,
mut dst: &mut [u8],
) -> Result<usize, SerializeError> {
let nwrite = self.write_to_len();
if dst.len() < nwrite {
return Err(SerializeError::buffer_too_small(
"sparse state transitions",
));
}
let ntrans =
if self.is_match { self.ntrans | (1 << 15) } else { self.ntrans };
E::write_u16(u16::try_from(ntrans).unwrap(), dst);
dst = &mut dst[size_of::<u16>()..];
dst[..self.input_ranges.len()].copy_from_slice(self.input_ranges);
dst = &mut dst[self.input_ranges.len()..];
for i in 0..self.ntrans {
E::write_u32(self.next_at(i).as_u32(), dst);
dst = &mut dst[StateID::SIZE..];
}
if self.is_match {
E::write_u32(u32::try_from(self.pattern_len()).unwrap(), dst);
dst = &mut dst[size_of::<u32>()..];
for i in 0..self.pattern_len() {
let pid = self.pattern_id(i);
E::write_u32(pid.as_u32(), dst);
dst = &mut dst[PatternID::SIZE..];
}
}
dst[0] = u8::try_from(self.accel.len()).unwrap();
dst[1..][..self.accel.len()].copy_from_slice(self.accel);
Ok(nwrite)
}
fn write_to_len(&self) -> usize {
let mut len = 2
+ (self.ntrans * 2)
+ (self.ntrans * StateID::SIZE)
+ (1 + self.accel.len());
if self.is_match {
len += size_of::<u32>() + self.pattern_ids.len();
}
len
}
}
impl<'a> fmt::Debug for State<'a> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut printed = false;
for i in 0..(self.ntrans - 1) {
let next = self.next_at(i);
if next == DEAD {
continue;
}
if printed {
write!(f, ", ")?;
}
let (start, end) = self.range(i);
if start == end {
write!(f, "{:?} => {:?}", DebugByte(start), next.as_usize())?;
} else {
write!(
f,
"{:?}-{:?} => {:?}",
DebugByte(start),
DebugByte(end),
next.as_usize(),
)?;
}
printed = true;
}
let eoi = self.next_at(self.ntrans - 1);
if eoi != DEAD {
if printed {
write!(f, ", ")?;
}
write!(f, "EOI => {:?}", eoi.as_usize())?;
}
Ok(())
}
}
#[cfg(feature = "dfa-build")]
struct StateMut<'a> {
id: StateID,
is_match: bool,
ntrans: usize,
input_ranges: &'a mut [u8],
next: &'a mut [u8],
pattern_ids: &'a [u8],
accel: &'a mut [u8],
}
#[cfg(feature = "dfa-build")]
impl<'a> StateMut<'a> {
fn set_next_at(&mut self, i: usize, next: StateID) {
let start = i * StateID::SIZE;
let end = start + StateID::SIZE;
wire::write_state_id::<wire::NE>(next, &mut self.next[start..end]);
}
}
#[cfg(feature = "dfa-build")]
impl<'a> fmt::Debug for StateMut<'a> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let state = State {
id: self.id,
is_match: self.is_match,
ntrans: self.ntrans,
input_ranges: self.input_ranges,
next: self.next,
pattern_ids: self.pattern_ids,
accel: self.accel,
};
fmt::Debug::fmt(&state, f)
}
}
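/// A set of state IDs observed while validating the transition table.
///
/// With the `alloc` feature, this is a real set and transition targets are
/// checked against it. Without `alloc`, it is a zero-sized stand-in whose
/// `contains` always returns true; validation then falls back to re-parsing
/// each transition target via `try_state` (see `Transitions::validate`).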
#[derive(Debug)]
struct Seen {
#[cfg(feature = "alloc")]
set: alloc::collections::BTreeSet<StateID>,
#[cfg(not(feature = "alloc"))]
set: core::marker::PhantomData<StateID>,
}
#[cfg(feature = "alloc")]
impl Seen {
fn new() -> Seen {
Seen { set: alloc::collections::BTreeSet::new() }
}
fn insert(&mut self, id: StateID) {
self.set.insert(id);
}
fn contains(&self, id: &StateID) -> bool {
self.set.contains(id)
}
}
#[cfg(not(feature = "alloc"))]
impl Seen {
fn new() -> Seen {
Seen { set: core::marker::PhantomData }
}
fn insert(&mut self, _id: StateID) {}
fn contains(&self, _id: &StateID) -> bool {
true
}
}
#[cfg(all(test, feature = "syntax", feature = "dfa-build"))]
mod tests {
use crate::{
dfa::{dense::DFA, Automaton},
nfa::thompson,
Input, MatchError,
};
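    // These tests exercise the heuristic Unicode word boundary support: when
    // the pattern contains \b and a non-ASCII byte appears at or adjacent to
    // the searched span, the search must return a "quit" error rather than a
    // possibly incorrect result.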
#[test]
fn heuristic_unicode_forward() {
let dfa = DFA::builder()
.configure(DFA::config().unicode_word_boundary(true))
.thompson(thompson::Config::new().reverse(true))
.build(r"\b[0-9]+\b")
.unwrap()
.to_sparse()
.unwrap();
let input = Input::new("β123").range(2..);
let expected = MatchError::quit(0xB2, 1);
let got = dfa.try_search_fwd(&input);
assert_eq!(Err(expected), got);
let input = Input::new("123β").range(..3);
let expected = MatchError::quit(0xCE, 3);
let got = dfa.try_search_fwd(&input);
assert_eq!(Err(expected), got);
}
#[test]
fn heuristic_unicode_reverse() {
let dfa = DFA::builder()
.configure(DFA::config().unicode_word_boundary(true))
.thompson(thompson::Config::new().reverse(true))
.build(r"\b[0-9]+\b")
.unwrap()
.to_sparse()
.unwrap();
let input = Input::new("β123").range(2..);
let expected = MatchError::quit(0xB2, 1);
let got = dfa.try_search_rev(&input);
assert_eq!(Err(expected), got);
let input = Input::new("123β").range(..3);
let expected = MatchError::quit(0xCE, 3);
let got = dfa.try_search_rev(&input);
assert_eq!(Err(expected), got);
}
}