use super::types::{DfaPackFormat, PackedDfa, PatternError};
use super::{dfa_minimize, dfa_pack, nfa_to_dfa, regex_to_nfa};
/// One input pattern for [`dfa_assemble`].
///
/// Both variants only borrow their source, so a `Pattern` is as cheap to copy
/// as a pair of pointers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Pattern<'a> {
    /// Matches exactly this byte sequence: regex metacharacters and non-ASCII
    /// bytes are escaped before compilation, so every byte is matched literally.
    Literal(&'a [u8]),
    /// Regex source text, handed to the crate's regex parser unchanged.
    Regex(&'a str),
}
/// Settings that control how [`dfa_assemble`] builds its output.
///
/// The struct is `#[non_exhaustive]`: code outside this crate cannot use a
/// struct literal, so it should start from [`AssembleOptions::default`] and
/// overwrite the public fields it cares about.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub struct AssembleOptions {
    /// Packing format handed to `dfa_pack` for the final DFA.
    pub format: DfaPackFormat,
    /// When `true`, the DFA is minimized before it is packed.
    pub minimize: bool,
}

impl Default for AssembleOptions {
    /// Dense packing with minimization turned on.
    fn default() -> Self {
        let minimize = true;
        let format = DfaPackFormat::Dense;
        Self { format, minimize }
    }
}
/// Compiles `patterns` into one packed DFA that accepts the union of their
/// languages.
///
/// Literal patterns are escaped so each of their bytes is matched literally;
/// regex patterns are passed to the regex parser as written. All patterns are
/// joined into a single alternation, converted to an NFA, determinized,
/// minimized when `options.minimize` is set, and packed using `options.format`.
///
/// # Errors
///
/// - [`PatternError::EmptyPatternSet`] if `patterns` is empty or contains an
///   empty [`Pattern::Literal`].
/// - Any error reported by the regex parser / NFA construction (for example
///   `PatternError::ParseError` for a malformed regex) or by determinization.
pub fn dfa_assemble(
    patterns: &[Pattern<'_>],
    options: AssembleOptions,
) -> Result<PackedDfa, PatternError> {
    let has_empty_literal = patterns
        .iter()
        .any(|pattern| matches!(pattern, Pattern::Literal(bytes) if bytes.is_empty()));
    if patterns.is_empty() || has_empty_literal {
        return Err(PatternError::EmptyPatternSet);
    }

    // The patterns are joined with a bare `|`, so without this check a malformed
    // regex could be "completed" by its neighbours: `Regex("(a")` + `Regex("b)")`
    // would combine into the valid `(a|b)`, and a trailing backslash in
    // `Regex("a\\")` would escape the separator and silently concatenate two
    // patterns. Parsing each regex on its own rejects such input (and reports
    // error positions relative to that pattern), at the cost of parsing regex
    // patterns twice. A single pattern *is* the combined regex, so skip it then.
    if patterns.len() > 1 {
        for pattern in patterns {
            if let Pattern::Regex(source) = pattern {
                let _ = regex_to_nfa::regex_to_nfa(source)?;
            }
        }
    }

    let combined_regex = combined_pattern_regex(patterns);
    let nfa = regex_to_nfa::regex_to_nfa(&combined_regex)?;
    let dfa = nfa_to_dfa::nfa_to_dfa(&nfa)?;
    let final_dfa = if options.minimize {
        dfa_minimize::dfa_minimize(&dfa)
    } else {
        dfa
    };
    Ok(dfa_pack::dfa_pack(&final_dfa, options.format))
}
/// Joins every pattern's regex text into one alternation: `p0|p1|...|pn`.
///
/// Literals are escaped by `push_pattern_regex`; regex sources are copied
/// verbatim, so the result is only well-formed if each regex is.
fn combined_pattern_regex(patterns: &[Pattern<'_>]) -> String {
    // Reserve room for every pattern *and* the `|` separators between them, so
    // an exact (or upper-bound) hint never forces a reallocation.
    let separators = patterns.len().saturating_sub(1);
    let body_len: usize = patterns.iter().map(pattern_regex_len_hint).sum();
    let mut out = String::with_capacity(body_len + separators);
    for (index, pattern) in patterns.iter().enumerate() {
        if index != 0 {
            out.push('|');
        }
        push_pattern_regex(&mut out, pattern);
    }
    out
}
/// Returns an upper bound on the number of bytes `push_pattern_regex` appends
/// for `pattern`; used only to pre-size the combined regex buffer.
///
/// A literal ASCII byte becomes at most 2 bytes (the byte itself, or a
/// backslash escape such as `\.`), and a non-ASCII byte becomes exactly 6
/// (`[\xHH]`). A regex source is copied verbatim.
fn pattern_regex_len_hint(pattern: &Pattern<'_>) -> usize {
    match pattern {
        Pattern::Literal(bytes) => bytes
            .iter()
            .map(|&byte| if byte.is_ascii() { 2 } else { 6 })
            .sum(),
        Pattern::Regex(source) => source.len(),
    }
}
/// Appends the regex text for a single `pattern` to `out`.
///
/// Literal bytes are escaped one at a time via `escape_literal_byte`; a regex
/// source is appended unchanged.
fn push_pattern_regex(out: &mut String, pattern: &Pattern<'_>) {
    match *pattern {
        Pattern::Regex(source) => out.push_str(source),
        Pattern::Literal(literal) => literal
            .iter()
            .for_each(|&byte| escape_literal_byte(out, byte)),
    }
}
/// Appends `byte` to `out` so that the regex matches exactly that one byte.
///
/// - Regex metacharacters get a leading backslash.
/// - Every other ASCII byte is appended as-is.
/// - A non-ASCII byte is written as the class `[\xHH]` (uppercase hex) instead
///   of being pushed as a `char`, which would UTF-8 encode it as two bytes.
fn escape_literal_byte(out: &mut String, byte: u8) {
    match byte {
        b'.' | b'*' | b'+' | b'?' | b'|' | b'(' | b')' | b'[' | b']' | b'\\' | b'^' | b'$' => {
            out.push('\\');
            out.push(char::from(byte));
        }
        _ if byte.is_ascii() => out.push(char::from(byte)),
        _ => {
            out.push_str("[\\x");
            push_hex_byte(out, byte);
            out.push(']');
        }
    }
}
/// Appends `byte` to `out` as exactly two uppercase hexadecimal digits,
/// high nibble first.
fn push_hex_byte(out: &mut String, byte: u8) {
    for nibble in [byte >> 4, byte & 0x0F] {
        let digit = char::from_digit(u32::from(nibble), 16)
            .expect("a nibble is always below 16");
        out.push(digit.to_ascii_uppercase());
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pattern::{dfa_pack::dfa_unpack, types::INVALID_STATE};

    /// Unpacks `packed`, runs `input` through it byte by byte, and reports
    /// whether the walk ends in an accepting state. Hitting `INVALID_STATE`
    /// at any point counts as rejection.
    fn accepts(packed: &PackedDfa, input: &[u8]) -> bool {
        let dfa = dfa_unpack(packed).expect("unpack");
        let mut state = dfa.start;
        for &b in input {
            let next = dfa.go(state, b);
            if next == INVALID_STATE {
                return false;
            }
            state = next;
        }
        dfa.accept[state as usize]
    }

    #[test]
    fn single_literal_pattern() {
        let patterns = [Pattern::Literal(b"hello")];
        let packed = dfa_assemble(&patterns, AssembleOptions::default()).unwrap();
        assert!(accepts(&packed, b"hello"));
        assert!(!accepts(&packed, b"hell"));
        assert!(!accepts(&packed, b"world"));
    }

    #[test]
    fn multiple_literal_patterns() {
        let patterns = [Pattern::Literal(b"foo"), Pattern::Literal(b"bar")];
        let packed = dfa_assemble(&patterns, AssembleOptions::default()).unwrap();
        assert!(accepts(&packed, b"foo"));
        assert!(accepts(&packed, b"bar"));
        assert!(!accepts(&packed, b"baz"));
    }

    #[test]
    fn regex_pattern() {
        let patterns = [Pattern::Regex("a(b|c)*d")];
        let packed = dfa_assemble(&patterns, AssembleOptions::default()).unwrap();
        assert!(accepts(&packed, b"ad"));
        assert!(accepts(&packed, b"abd"));
        assert!(accepts(&packed, b"abcbcd"));
        assert!(!accepts(&packed, b"a"));
    }

    #[test]
    fn literal_escapes_metachars() {
        let patterns = [Pattern::Literal(b"a.b*c")];
        let packed = dfa_assemble(&patterns, AssembleOptions::default()).unwrap();
        assert!(accepts(&packed, b"a.b*c"));
        assert!(!accepts(&packed, b"axbxc"));
    }

    #[test]
    fn literal_escapes_alternation_and_grouping_metachars() {
        let patterns = [Pattern::Literal(b"a|(b)\\")];
        let packed = dfa_assemble(&patterns, AssembleOptions::default()).unwrap();
        assert!(accepts(&packed, b"a|(b)\\"));
        assert!(!accepts(&packed, b"a"));
        assert!(!accepts(&packed, b"b"));
    }

    #[test]
    fn literal_with_pipe_stays_separate_from_next_pattern() {
        let patterns = [Pattern::Literal(b"x|"), Pattern::Literal(b"y")];
        let packed = dfa_assemble(&patterns, AssembleOptions::default()).unwrap();
        assert!(accepts(&packed, b"x|"));
        assert!(accepts(&packed, b"y"));
        assert!(!accepts(&packed, b"x"));
        assert!(!accepts(&packed, b"x|y"));
    }

    #[test]
    fn literal_non_ascii_byte_stays_single_byte() {
        let patterns = [Pattern::Literal(&[0xE9])];
        let packed = dfa_assemble(&patterns, AssembleOptions::default()).unwrap();
        assert!(accepts(&packed, &[0xE9]));
        assert!(!accepts(&packed, "\u{00E9}".as_bytes()));
    }

    #[test]
    fn mixed_literal_and_regex() {
        let patterns = [Pattern::Literal(b"exact"), Pattern::Regex("[0-9]+")];
        let packed = dfa_assemble(&patterns, AssembleOptions::default()).unwrap();
        assert!(accepts(&packed, b"exact"));
        assert!(accepts(&packed, b"12345"));
        assert!(!accepts(&packed, b"hello"));
    }

    #[test]
    fn regex_alternation_unions_with_other_patterns() {
        let patterns = [Pattern::Regex("a|b"), Pattern::Literal(b"c")];
        let packed = dfa_assemble(&patterns, AssembleOptions::default()).unwrap();
        assert!(accepts(&packed, b"a"));
        assert!(accepts(&packed, b"b"));
        assert!(accepts(&packed, b"c"));
        assert!(!accepts(&packed, b"ab"));
        assert!(!accepts(&packed, b"bc"));
    }

    #[test]
    fn empty_pattern_set_errors() {
        let err = dfa_assemble(&[], AssembleOptions::default()).unwrap_err();
        assert!(matches!(err, PatternError::EmptyPatternSet));
    }

    #[test]
    fn empty_literal_pattern_errors() {
        let err = dfa_assemble(&[Pattern::Literal(b"")], AssembleOptions::default()).unwrap_err();
        assert!(matches!(err, PatternError::EmptyPatternSet));
    }

    #[test]
    fn malformed_regex_errors() {
        let patterns = [Pattern::Regex("(unclosed")];
        assert!(matches!(
            dfa_assemble(&patterns, AssembleOptions::default()),
            Err(PatternError::ParseError { .. })
        ));
    }

    #[test]
    fn equiv_class_format() {
        let patterns = [Pattern::Literal(b"abc")];
        let options = AssembleOptions {
            format: DfaPackFormat::EquivClass,
            minimize: true,
        };
        let packed = dfa_assemble(&patterns, options).unwrap();
        assert_eq!(packed.format, DfaPackFormat::EquivClass);
        assert!(accepts(&packed, b"abc"));
    }

    #[test]
    fn minimize_off_preserves_language() {
        let patterns = [Pattern::Literal(b"xy")];
        let options = AssembleOptions {
            format: DfaPackFormat::Dense,
            minimize: false,
        };
        let packed = dfa_assemble(&patterns, options).unwrap();
        assert!(accepts(&packed, b"xy"));
        assert!(!accepts(&packed, b"xz"));
    }
}