use core::iter::Enumerate;
use core::num::NonZeroU32;
use crate::bytewise::DoubleArrayAhoCorasick;
use crate::Match;
use crate::bytewise::ROOT_STATE_IDX;
use crate::utils::FromU32;
/// Byte-by-byte iterator over any container that can be viewed as `[u8]`.
///
/// The container itself is stored (not a borrow of it), so callers can pass
/// owned values such as `Vec<u8>` or `String` as well as slices.
#[doc(hidden)]
pub struct U8SliceIterator<P> {
    /// The byte container being iterated.
    inner: P,
    /// Index of the next byte to yield.
    pos: usize,
}

impl<P> U8SliceIterator<P>
where
    P: AsRef<[u8]>,
{
    /// Creates an iterator positioned at the first byte of `inner`.
    // NOTE(review): presumably kept non-`const` for MSRV reasons (trait bounds
    // on generic `const fn` need Rust 1.61) — confirm before removing.
    #[allow(clippy::missing_const_for_fn)]
    pub(crate) fn new(inner: P) -> Self {
        Self { inner, pos: 0 }
    }
}

impl<P> Iterator for U8SliceIterator<P>
where
    P: AsRef<[u8]>,
{
    type Item = u8;

    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        let ret = *self.inner.as_ref().get(self.pos)?;
        self.pos += 1;
        Some(ret)
    }

    /// Reports the exact number of bytes left, so adaptors such as
    /// `Enumerate` and `collect` can size themselves up front.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.inner.as_ref().len().saturating_sub(self.pos);
        (remaining, Some(remaining))
    }
}

// `size_hint` above is exact, so `len()` is available for free.
impl<P: AsRef<[u8]>> ExactSizeIterator for U8SliceIterator<P> {}

// Once `pos` reaches the end, `get(pos)` keeps returning `None` and `pos` stops
// advancing, so the iterator stays exhausted.
impl<P: AsRef<[u8]>> core::iter::FusedIterator for U8SliceIterator<P> {}
/// Iterator over non-overlapping matches.
///
/// Every call to `next` restarts the automaton at the root state and resumes
/// the haystack right after the last byte of the previously reported match.
pub struct FindIterator<'a, P, V> {
    /// The automaton driving the search.
    pub(crate) pma: &'a DoubleArrayAhoCorasick<V>,
    /// The not-yet-scanned haystack bytes, each paired with its index.
    pub(crate) haystack: Enumerate<P>,
}

impl<P, V> Iterator for FindIterator<'_, P, V>
where
    P: Iterator<Item = u8>,
    V: Copy,
{
    type Item = Match<V>;

    /// Consumes bytes until a state carrying an output is reached and returns
    /// that output; `end` is one past the index of the completing byte.
    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        let pma = self.pma;
        let mut current = ROOT_STATE_IDX;
        self.haystack.find_map(|(idx, byte)| {
            // SAFETY (assumed — confirm against `next_state_id_unchecked`'s
            // contract): `current` is ROOT_STATE_IDX or an id returned by this
            // same call for the previous byte.
            current = unsafe { pma.next_state_id_unchecked(current, byte) };
            // SAFETY (assumed): ids produced by the transition function index
            // into `pma.states`.
            let hit = unsafe {
                pma.states
                    .get_unchecked(usize::from_u32(current))
                    .output_pos()
            }?;
            // SAFETY (assumed): output positions are stored 1-based; `hit` is
            // non-zero, so `- 1` cannot underflow.
            let out = unsafe { pma.outputs.get_unchecked(usize::from_u32(hit.get() - 1)) };
            Some(Match {
                length: usize::from_u32(out.length()),
                end: idx + 1,
                value: out.value(),
            })
        })
    }
}
/// Iterator over overlapping matches.
///
/// The automaton state persists across calls. When a state with an output is
/// reached, that output is returned first and its `parent` chain is queued
/// in `output_pos`, to be drained by the following calls before any more
/// haystack bytes are read.
pub struct FindOverlappingIterator<'a, P, V> {
    /// The automaton driving the search.
    pub(crate) pma: &'a DoubleArrayAhoCorasick<V>,
    /// The not-yet-scanned haystack bytes, each paired with its index.
    pub(crate) haystack: Enumerate<P>,
    /// Current automaton state (presumably ROOT_STATE_IDX initially — set by
    /// the constructor elsewhere).
    pub(crate) state_id: u32,
    /// `end` reported for the outputs of the most recently reached state.
    pub(crate) pos: usize,
    /// Next queued output (1-based) still to be reported, if any.
    pub(crate) output_pos: Option<NonZeroU32>,
}

impl<P, V> Iterator for FindOverlappingIterator<'_, P, V>
where
    P: Iterator<Item = u8>,
    V: Copy,
{
    type Item = Match<V>;

    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        // Either drain a queued output, or scan forward to the next state that
        // has one (recording where it ended).
        let pending = match self.output_pos {
            Some(queued) => queued,
            None => {
                let pma = self.pma;
                let state_id = &mut self.state_id;
                let (idx, found) = self.haystack.find_map(|(idx, byte)| {
                    // SAFETY (assumed — confirm against
                    // `next_state_id_unchecked`'s contract): `state_id` only
                    // ever holds ids produced by the transition function.
                    *state_id = unsafe { pma.next_state_id_unchecked(*state_id, byte) };
                    // SAFETY (assumed): transition results index into `states`.
                    let reached = unsafe {
                        pma.states
                            .get_unchecked(usize::from_u32(*state_id))
                            .output_pos()
                    };
                    reached.map(|first| (idx, first))
                })?;
                self.pos = idx + 1;
                found
            }
        };
        // SAFETY (assumed): output positions are stored 1-based; `pending` is
        // non-zero, so `- 1` cannot underflow.
        let out = unsafe {
            self.pma
                .outputs
                .get_unchecked(usize::from_u32(pending.get() - 1))
        };
        self.output_pos = out.parent();
        Some(Match {
            length: usize::from_u32(out.length()),
            end: self.pos,
            value: out.value(),
        })
    }
}
/// Iterator over overlapping matches that reports only the first output of
/// each reached state; its `parent` chain is never followed.
///
/// The automaton state persists across calls, so matches may overlap.
pub struct FindOverlappingNoSuffixIterator<'a, P, V> {
    /// The automaton driving the search.
    pub(crate) pma: &'a DoubleArrayAhoCorasick<V>,
    /// The not-yet-scanned haystack bytes, each paired with its index.
    pub(crate) haystack: Enumerate<P>,
    /// Current automaton state (presumably ROOT_STATE_IDX initially — set by
    /// the constructor elsewhere).
    pub(crate) state_id: u32,
}

impl<P, V> Iterator for FindOverlappingNoSuffixIterator<'_, P, V>
where
    P: Iterator<Item = u8>,
    V: Copy,
{
    type Item = Match<V>;

    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        let pma = self.pma;
        let state_id = &mut self.state_id;
        self.haystack.find_map(|(idx, byte)| {
            // SAFETY (assumed — confirm against `next_state_id_unchecked`'s
            // contract): `state_id` only ever holds ids produced by the
            // transition function.
            *state_id = unsafe { pma.next_state_id_unchecked(*state_id, byte) };
            // SAFETY (assumed): transition results index into `states`.
            let hit = unsafe {
                pma.states
                    .get_unchecked(usize::from_u32(*state_id))
                    .output_pos()
            }?;
            // SAFETY (assumed): output positions are stored 1-based; `hit` is
            // non-zero, so `- 1` cannot underflow.
            let out = unsafe { pma.outputs.get_unchecked(usize::from_u32(hit.get() - 1)) };
            Some(Match {
                length: usize::from_u32(out.length()),
                end: idx + 1,
                value: out.value(),
            })
        })
    }
}
/// Iterator over leftmost, non-overlapping matches in a byte haystack.
///
/// Each call searches from `pos` using the leftmost transition function and
/// returns the last output seen before the automaton falls back to the root
/// (or before the haystack ends).
pub struct LeftmostFindIterator<'a, P, V>
where
    P: AsRef<[u8]>,
{
    /// The automaton driving the search.
    pub(crate) pma: &'a DoubleArrayAhoCorasick<V>,
    /// The haystack being searched.
    pub(crate) haystack: P,
    /// Index at which the next search begins; after a match is returned it
    /// equals that match's `end`.
    pub(crate) pos: usize,
}

impl<P, V> Iterator for LeftmostFindIterator<'_, P, V>
where
    P: AsRef<[u8]>,
    V: Copy,
{
    type Item = Match<V>;

    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        let mut state_id = ROOT_STATE_IDX;
        // Most recent output seen in this search; `self.pos` is kept equal to
        // its end position.
        let mut last_output_pos: Option<NonZeroU32> = None;
        let haystack = self.haystack.as_ref();
        for (pos, &c) in haystack.iter().enumerate().skip(self.pos) {
            // SAFETY (assumed — confirm against
            // `next_state_id_leftmost_unchecked`'s contract): `state_id` is
            // ROOT_STATE_IDX or an id returned by the previous transition.
            state_id = unsafe { self.pma.next_state_id_leftmost_unchecked(state_id, c) };
            if state_id == ROOT_STATE_IDX {
                // Falling back to the root ends the current candidate; if an
                // output was already seen, report it.
                if last_output_pos.is_some() {
                    break;
                }
            } else if let Some(output_pos) = unsafe {
                // SAFETY (assumed): non-root transition results index into
                // `states`.
                self.pma
                    .states
                    .get_unchecked(usize::from_u32(state_id))
                    .output_pos()
            } {
                last_output_pos = Some(output_pos);
                self.pos = pos + 1;
            }
        }
        match last_output_pos {
            Some(output_pos) => {
                // SAFETY (assumed): output positions are stored 1-based;
                // `output_pos` is non-zero, so `- 1` cannot underflow.
                let out = unsafe {
                    self.pma
                        .outputs
                        .get_unchecked(usize::from_u32(output_pos.get() - 1))
                };
                Some(Match {
                    length: usize::from_u32(out.length()),
                    end: self.pos,
                    value: out.value(),
                })
            }
            None => {
                // The scan from `self.pos` to the end found no output, and
                // `self.pos` was not moved, so every later call would repeat
                // this whole scan just to return `None` again. Jump to the end
                // so those calls are O(1).
                self.pos = haystack.len();
                None
            }
        }
    }
}
/// Incremental non-overlapping searcher that is fed one byte at a time.
pub struct FindStepper<'a, V> {
    /// The automaton driving the search.
    pub(crate) pma: &'a DoubleArrayAhoCorasick<V>,
    /// Current automaton state; reset to the root after every match.
    pub(crate) state_id: u32,
    /// Number of bytes counted so far; used as the `end` of reported matches.
    pub(crate) pos: usize,
}

impl<V> FindStepper<'_, V>
where
    V: Copy,
{
    /// Feeds byte `c` to the automaton and returns the match it completes,
    /// if any.
    ///
    /// A reported match's `end` is `pos` after this byte has been counted.
    /// After a match, the state goes back to the root so matches never overlap.
    #[inline(always)]
    pub fn consume(&mut self, c: u8) -> Option<Match<V>> {
        let pma = self.pma;
        // SAFETY (assumed — confirm against `next_state_id_unchecked`'s
        // contract): `state_id` only ever holds ROOT_STATE_IDX or a
        // transition result.
        self.state_id = unsafe { pma.next_state_id_unchecked(self.state_id, c) };
        self.pos += 1;
        // SAFETY (assumed): transition results index into `states`.
        let hit = unsafe {
            pma.states
                .get_unchecked(usize::from_u32(self.state_id))
                .output_pos()
        }?;
        // SAFETY (assumed): output positions are stored 1-based; `hit` is
        // non-zero, so `- 1` cannot underflow.
        let out = unsafe { pma.outputs.get_unchecked(usize::from_u32(hit.get() - 1)) };
        self.state_id = ROOT_STATE_IDX;
        Some(Match {
            length: usize::from_u32(out.length()),
            end: self.pos,
            value: out.value(),
        })
    }
}
/// Iterator over all outputs attached to a single automaton state.
///
/// It walks the `parent` chain that starts at `output_pos`, and every item it
/// yields shares the same `end`.
pub struct FindOverlappingStepperIterator<'a, V> {
    /// The automaton the outputs belong to.
    pub(crate) pma: &'a DoubleArrayAhoCorasick<V>,
    /// `end` reported for every yielded match.
    pub(crate) pos: usize,
    /// Next output (1-based) to report, or `None` when the chain is done.
    pub(crate) output_pos: Option<NonZeroU32>,
}

impl<V> Iterator for FindOverlappingStepperIterator<'_, V>
where
    V: Copy,
{
    type Item = Match<V>;

    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        let current = self.output_pos?;
        // SAFETY (assumed): output positions are stored 1-based; `current` is
        // non-zero, so `- 1` cannot underflow.
        let out = unsafe {
            self.pma
                .outputs
                .get_unchecked(usize::from_u32(current.get() - 1))
        };
        self.output_pos = out.parent();
        Some(Match {
            length: usize::from_u32(out.length()),
            end: self.pos,
            value: out.value(),
        })
    }
}
/// Incremental overlapping searcher that is fed one byte at a time.
pub struct FindOverlappingStepper<'a, V> {
    /// The automaton driving the search.
    pub(crate) pma: &'a DoubleArrayAhoCorasick<V>,
    /// Current automaton state; never reset between bytes.
    pub(crate) state_id: u32,
    /// Number of bytes counted so far; used as the `end` of reported matches.
    pub(crate) pos: usize,
}

impl<'a, V> FindOverlappingStepper<'a, V>
where
    V: Copy,
{
    /// Feeds byte `c` and returns an iterator over every output of the state
    /// it leads to.
    ///
    /// The returned iterator borrows only the automaton (lifetime `'a`), not
    /// this stepper, so it may outlive further calls to `consume`.
    #[inline(always)]
    pub fn consume(&mut self, c: u8) -> FindOverlappingStepperIterator<'a, V> {
        let pma = self.pma;
        // SAFETY (assumed — confirm against `next_state_id_unchecked`'s
        // contract): `state_id` only ever holds ROOT_STATE_IDX or a
        // transition result.
        let next_state = unsafe { pma.next_state_id_unchecked(self.state_id, c) };
        self.state_id = next_state;
        self.pos += 1;
        FindOverlappingStepperIterator {
            pma,
            pos: self.pos,
            // SAFETY (assumed): transition results index into `states`.
            output_pos: unsafe {
                pma.states
                    .get_unchecked(usize::from_u32(next_state))
                    .output_pos()
            },
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Iterators returned by `FindOverlappingStepper::consume` borrow the
    /// automaton rather than the stepper, so several can be alive at once.
    /// Each must still yield exactly its own matches and then stop.
    #[test]
    fn test_overlapping_stepper_lifetime() {
        let pma = DoubleArrayAhoCorasick::new(["a", "ab"]).unwrap();
        let mut stepper = pma.find_overlapping_stepper();
        let mut it1 = stepper.consume(b'a');
        let mut it2 = stepper.consume(b'b');
        assert_eq!(
            Some(Match {
                length: 1,
                end: 1,
                value: 0
            }),
            it1.next()
        );
        // "a" has no shorter suffix pattern, so nothing else ends at 1.
        assert_eq!(None, it1.next());
        assert_eq!(
            Some(Match {
                length: 2,
                end: 2,
                value: 1
            }),
            it2.next()
        );
        // "b" is not a pattern, so "ab" is the only match ending at 2.
        assert_eq!(None, it2.next());
    }
}