mod flat_contains;
mod prefix;
#[cfg(test)]
mod tests;
use std::borrow::Cow;
use flat_contains::FlatContainsDfa;
use fsst::ESCAPE_CODE;
use fsst::Symbol;
use prefix::FlatPrefixDfa;
use vortex_buffer::BitBuffer;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
/// Matcher for a LIKE pattern that is evaluated directly on encoded code bytes.
///
/// Built with `FsstMatcher::try_new` from a symbol table and a pattern, then
/// queried once per string with `FsstMatcher::matches`.
pub(crate) struct FsstMatcher {
/// The concrete matching strategy selected for the pattern.
inner: MatcherInner,
}
/// The matching strategy backing an `FsstMatcher`.
enum MatcherInner {
/// The pattern's literal part is empty, so every input matches.
MatchAll,
/// Prefix pattern, evaluated by a `FlatPrefixDfa`.
Prefix(FlatPrefixDfa),
/// Contains pattern, evaluated by a `FlatContainsDfa`.
Contains(FlatContainsDfa),
}
impl FsstMatcher {
    /// Tries to compile `pattern` into a matcher over the given symbol table.
    ///
    /// Returns `Ok(None)` when the pattern cannot be handled here: either
    /// `LikeKind::parse` does not recognise its shape, or its literal part is
    /// longer than the corresponding DFA supports
    /// (`FlatPrefixDfa::MAX_PREFIX_LEN` / `FlatContainsDfa::MAX_NEEDLE_LEN`).
    /// A pattern whose literal part is empty compiles to a matcher that accepts
    /// every input.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `FlatPrefixDfa::new` or
    /// `FlatContainsDfa::new`.
    pub(crate) fn try_new(
        symbols: &[Symbol],
        symbol_lengths: &[u8],
        pattern: &[u8],
    ) -> VortexResult<Option<Self>> {
        let inner = match LikeKind::parse(pattern) {
            // Unsupported pattern shape: the caller has to use another strategy.
            None => return Ok(None),
            // An empty literal matches everything, so no DFA is needed.
            Some(LikeKind::Prefix(literal) | LikeKind::Contains(literal))
                if literal.is_empty() =>
            {
                MatcherInner::MatchAll
            }
            Some(LikeKind::Prefix(prefix)) => {
                if FlatPrefixDfa::MAX_PREFIX_LEN < prefix.len() {
                    return Ok(None);
                }
                let dfa = FlatPrefixDfa::new(symbols, symbol_lengths, prefix.as_ref())?;
                MatcherInner::Prefix(dfa)
            }
            Some(LikeKind::Contains(needle)) => {
                if FlatContainsDfa::MAX_NEEDLE_LEN < needle.len() {
                    return Ok(None);
                }
                let dfa = FlatContainsDfa::new(symbols, symbol_lengths, needle.as_ref())?;
                MatcherInner::Contains(dfa)
            }
        };

        Ok(Some(Self { inner }))
    }

    /// Reports whether the encoded string `codes` satisfies the compiled pattern.
    pub(crate) fn matches(&self, codes: &[u8]) -> bool {
        match self.inner {
            MatcherInner::MatchAll => true,
            MatcherInner::Prefix(ref prefix_dfa) => prefix_dfa.matches(codes),
            MatcherInner::Contains(ref contains_dfa) => contains_dfa.matches(codes),
        }
    }
}
/// The LIKE pattern shapes this module can compile, carrying the unescaped
/// literal part of the pattern.
///
/// The literal borrows from the pattern when it contains no backslash escapes
/// and is an owned, unescaped copy otherwise.
enum LikeKind<'a> {
    /// `literal%`: the input must start with the literal.
    Prefix(Cow<'a, [u8]>),
    /// `%literal%`: the input must contain the literal.
    Contains(Cow<'a, [u8]>),
}

impl<'a> LikeKind<'a> {
    /// Classifies `pattern`, trying the prefix shape first and the contains
    /// shape second. Returns `None` for every other pattern.
    fn parse(pattern: &'a [u8]) -> Option<Self> {
        if let Some(kind) = Self::parse_prefix(pattern) {
            return Some(kind);
        }
        Self::parse_contains(pattern)
    }

    /// Parses `literal%`, where the only unescaped wildcard is the final `%`.
    fn parse_prefix(pattern: &'a [u8]) -> Option<Self> {
        let literal = Self::parse_literal_until_final_percent(pattern, 0)?;
        Some(LikeKind::Prefix(literal))
    }

    /// Parses `%literal%`, where the only unescaped wildcards are the leading
    /// and the final `%`.
    fn parse_contains(pattern: &'a [u8]) -> Option<Self> {
        if pattern.first() != Some(&b'%') {
            return None;
        }
        let literal = Self::parse_literal_until_final_percent(pattern, 1)?;
        Some(LikeKind::Contains(literal))
    }

    /// Reads the literal that starts at `literal_start` and must be terminated
    /// by a `%` that is the very last byte of `pattern`.
    ///
    /// A backslash escapes the following byte; a trailing lone backslash stands
    /// for itself. Any other unescaped `%` or `_`, or a pattern that ends
    /// without the closing `%`, yields `None`.
    fn parse_literal_until_final_percent(
        pattern: &'a [u8],
        literal_start: usize,
    ) -> Option<Cow<'a, [u8]>> {
        // Allocated lazily: stays `None` until the first escape is seen, so
        // escape-free patterns can be returned as a borrowed slice.
        let mut unescaped: Option<Vec<u8>> = None;
        let mut pos = literal_start;

        while let Some(&byte) = pattern.get(pos) {
            if byte == b'\\' {
                let escaped = match pattern.get(pos + 1) {
                    Some(&next) => next,
                    None => b'\\',
                };
                let buf = unescaped.get_or_insert_with(|| pattern[literal_start..pos].to_vec());
                buf.push(escaped);
                pos = usize::min(pos + 2, pattern.len());
                continue;
            }

            if byte == b'%' || byte == b'_' {
                let is_final_percent = byte == b'%' && pos + 1 == pattern.len();
                if !is_final_percent {
                    return None;
                }
                let borrowed = Cow::Borrowed(&pattern[literal_start..pos]);
                return Some(unescaped.map_or(borrowed, Cow::Owned));
            }

            // Plain byte: only needs copying once we have switched to an owned buffer.
            if let Some(buf) = unescaped.as_mut() {
                buf.push(byte);
            }
            pos += 1;
        }

        // Ran off the end without seeing the terminating `%`.
        None
    }
}
/// Runs `matcher` over `n` consecutive byte ranges of `all_bytes` and collects
/// the outcomes into a `BitBuffer`.
///
/// Range `i` spans `offsets[i]..offsets[i + 1]`. Bit `i` of the result is the
/// matcher's verdict for that range, inverted when `negated` is `true`.
///
/// The start of each range is carried over from the end of the previous one,
/// so this assumes `BitBuffer::collect_bool` invokes the closure for
/// `0, 1, ..., n - 1` in ascending order (TODO confirm against its contract).
///
/// # Panics
///
/// Panics if `offsets` is empty, if an accessed `offsets[i + 1]` is out of
/// bounds, or if a range does not lie within `all_bytes`.
pub(crate) fn dfa_scan_to_bitbuf<T, F>(
    n: usize,
    offsets: &[T],
    all_bytes: &[u8],
    negated: bool,
    matcher: F,
) -> BitBuffer
where
    T: vortex_array::dtype::IntegerPType,
    F: Fn(&[u8]) -> bool,
{
    let mut cursor: usize = offsets[0].as_();
    BitBuffer::collect_bool(n, |row| {
        let next: usize = offsets[row + 1].as_();
        let codes = &all_bytes[cursor..next];
        cursor = next;
        // XOR with `negated` flips the verdict exactly when negation is requested.
        negated ^ matcher(codes)
    })
}
/// Lifts a byte-level transition table to a symbol-level one.
///
/// `byte_table` is laid out as `byte_table[state * 256 + byte]`. The result is
/// laid out as `result[state * symbols.len() + code]` and holds the state
/// reached from `state` after feeding the first `symbol_lengths[code]`
/// little-endian bytes of `symbols[code]` through `byte_table`.
///
/// `accept_state` is absorbing: its whole row maps to itself, and walking a
/// symbol stops as soon as the accept state is reached.
fn build_symbol_transitions(
    symbols: &[Symbol],
    symbol_lengths: &[u8],
    byte_table: &[u8],
    n_states: u8,
    accept_state: u8,
) -> Vec<u8> {
    let n_symbols = symbols.len();
    let mut sym_trans = vec![0u8; usize::from(n_states) * n_symbols];

    for state in 0..n_states {
        let row_start = usize::from(state) * n_symbols;

        if state == accept_state {
            // Once accepted, always accepted, regardless of the symbol.
            sym_trans[row_start..row_start + n_symbols].fill(accept_state);
            continue;
        }

        for (code, symbol) in symbols.iter().enumerate() {
            let raw = symbol.to_u64().to_le_bytes();
            let len = usize::from(symbol_lengths[code]);

            // Feed the symbol's bytes one at a time, bailing out on acceptance.
            let mut current = state;
            let mut remaining = raw[..len].iter();
            while current != accept_state {
                let Some(&byte) = remaining.next() else {
                    break;
                };
                current = byte_table[usize::from(current) * 256 + usize::from(byte)];
            }

            sym_trans[row_start + code] = current;
        }
    }

    sym_trans
}
/// Expands a symbol-level transition table into one indexed by raw code byte.
///
/// The result is laid out as `fused[state * 256 + code_byte]`:
/// - code bytes below `n_symbols` take their target from
///   `sym_trans[state * n_symbols + code]`;
/// - the `ESCAPE_CODE` slot is set to `escape_value_fn(state)`, written after
///   the symbol entries so it takes precedence;
/// - every other slot holds `default`.
fn build_fused_table(
    sym_trans: &[u8],
    n_symbols: usize,
    n_states: u8,
    escape_value_fn: impl Fn(u8) -> u8,
    default: u8,
) -> Vec<u8> {
    let escape_slot = usize::from(ESCAPE_CODE);
    let mut fused = vec![default; usize::from(n_states) * 256];

    for state in 0..n_states {
        let fused_base = usize::from(state) * 256;
        let sym_base = usize::from(state) * n_symbols;

        (0..n_symbols).for_each(|code| {
            fused[fused_base + code] = sym_trans[sym_base + code];
        });

        fused[fused_base + escape_slot] = escape_value_fn(state);
    }

    fused
}
/// Builds the dense byte-level transition table of the KMP automaton for
/// `needle`.
///
/// There is one state per matched-prefix length, `0..=needle.len()`. The
/// table is laid out as `table[state * 256 + byte]` and gives the state
/// reached after reading `byte` in `state`. The final state (`needle.len()`)
/// is accepting and absorbing: every byte maps it back to itself.
///
/// # Panics
///
/// Panics if `needle.len() + 1` does not fit in a `u8`, i.e. if the needle is
/// longer than 254 bytes.
fn kmp_byte_transitions(needle: &[u8]) -> Vec<u8> {
    // `needle.len() + 1` states must be representable as `u8`, which bounds the
    // needle at 254 bytes (255 states).
    let n_states = u8::try_from(needle.len() + 1)
        .vortex_expect("kmp_byte_transitions: must have needle.len() ≤ 254");
    let accept = n_states - 1;
    let failure = kmp_failure_table(needle);

    let mut table = vec![0u8; usize::from(n_states) * 256];
    for state in 0..n_states {
        let row = usize::from(state) * 256;
        for byte in 0..256usize {
            if state == accept {
                // Accepting state is a sink.
                table[row + byte] = accept;
                continue;
            }

            // Follow failure links until `byte` extends the current border or
            // we are back at the start state with no match.
            let mut s = state;
            loop {
                if byte == usize::from(needle[usize::from(s)]) {
                    s += 1;
                    break;
                }
                if s == 0 {
                    break;
                }
                s = failure[usize::from(s) - 1];
            }
            table[row + byte] = s;
        }
    }
    table
}
/// Computes the KMP failure (longest proper border) table for `needle`.
///
/// Entry `i` is the length of the longest proper prefix of `needle[..=i]` that
/// is also a suffix of it. The result has one entry per needle byte and is
/// empty for an empty needle.
///
/// Border lengths are stored as `u8`, so the needle is expected to be short
/// enough for every border length to fit.
fn kmp_failure_table(needle: &[u8]) -> Vec<u8> {
    let mut failure = vec![0u8; needle.len()];
    let mut border = 0u8;

    for (pos, &byte) in needle.iter().enumerate().skip(1) {
        // Shrink the candidate border until it can be extended by `byte`.
        while border != 0 && needle[usize::from(border)] != byte {
            border = failure[usize::from(border) - 1];
        }
        if needle[usize::from(border)] == byte {
            border += 1;
        }
        failure[pos] = border;
    }

    failure
}