#![allow(clippy::unwrap_used, clippy::panic)]
use dfajit::{JitDfa, TransitionTable};
use matchkit::Match;
fn reset_table(state_count: usize) -> TransitionTable {
let mut table = TransitionTable::new(state_count, 256).unwrap();
for state in 0..state_count {
for byte in u8::MIN..=u8::MAX {
table.set_transition(state, byte, 0);
}
}
table
}
fn scan_interpreted(table: &TransitionTable, input: &[u8]) -> Vec<Match> {
let mut state = 0u32;
let mut matches = Vec::new();
for (pos, &byte) in input.iter().enumerate() {
let idx = state as usize * table.class_count() + byte as usize;
let next = table.transitions().get(idx).copied().unwrap_or(0);
let mut found_match = false;
for &(accept_state, pattern_id) in table.accept_states() {
let clean_next = next & 0x7FFF_FFFF;
if clean_next == accept_state {
let end = (pos + 1) as u32;
let pat_len = table
.pattern_lengths()
.get(pattern_id as usize)
.copied()
.unwrap_or(0);
let start = end.saturating_sub(pat_len);
matches.push(Match::from_parts(pattern_id, start, end));
found_match = true;
}
}
if found_match {
state = 0;
} else {
state = next & 0x7FFF_FFFF;
}
}
matches
}
#[test]
fn compile_single_literal_pattern() {
let jit = JitDfa::from_patterns(&[b"hello"]).unwrap();
assert_eq!(jit.state_count(), 6); assert_eq!(jit.pattern_count(), 1);
let mut matches = vec![Match::from_parts(0, 0, 0); 10];
let count = jit.scan(b"hello world", &mut matches);
assert_eq!(count, 1);
assert_eq!(matches[0].start, 0);
assert_eq!(matches[0].end, 5);
}
#[test]
fn compile_multiple_literal_patterns() {
let jit = JitDfa::from_patterns(&[b"foo", b"bar", b"baz"]).unwrap();
let mut matches = vec![Match::from_parts(0, 0, 0); 10];
let count = jit.scan(b"foo bar baz", &mut matches);
assert_eq!(count, 3);
assert_eq!(matches[0].pattern_id, 0);
assert_eq!(matches[1].pattern_id, 1);
assert_eq!(matches[2].pattern_id, 2);
}
#[test]
fn compile_zero_patterns_error() {
let result = JitDfa::from_patterns(&[] as &[&[u8]]);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("empty") || err.contains("zero"),
"Error should indicate empty pattern set"
);
}
#[test]
fn compile_many_patterns_1000() {
let patterns: Vec<Vec<u8>> = (0..1000)
.map(|i| format!("pat{}", i).into_bytes())
.collect();
let pattern_refs: Vec<&[u8]> = patterns.iter().map(|p| p.as_slice()).collect();
let jit = JitDfa::from_patterns(&pattern_refs).unwrap();
assert_eq!(jit.pattern_count(), 1000);
let mut matches = vec![Match::from_parts(0, 0, 0); 10];
let count = jit.scan(b"pat0 pat999", &mut matches);
assert_eq!(count, 2);
}
#[test]
#[cfg(feature = "regex")]
fn compile_regex_pattern_a_plus_b() {
let jit = JitDfa::from_regex_patterns(&["a+b"]).unwrap();
let mut matches = Vec::new();
let count = jit.scan(b"ab aab aaab b", &mut matches);
assert_eq!(count, 3);
}
#[test]
#[cfg(feature = "regex")]
fn compile_regex_pattern_alpha_plus() {
let jit = JitDfa::from_regex_patterns(&["[a-z]+"]).unwrap();
let mut matches = Vec::new();
let count = jit.scan(b"hello123world", &mut matches);
assert!(count >= 2);
}
#[test]
#[cfg(feature = "regex")]
fn compile_regex_pattern_dot_star() {
let jit = JitDfa::from_regex_patterns(&[".*"]).unwrap();
let mut matches = Vec::new();
let count = jit.scan(b"anything", &mut matches);
assert!(count >= 1);
}
#[test]
#[cfg(feature = "regex")]
fn compile_regex_pattern_digit_class() {
let jit = JitDfa::from_regex_patterns(&["[0-9]+"]).unwrap();
let mut matches = Vec::new();
let count = jit.scan(b"abc123def456", &mut matches);
assert!(
count >= 1,
"Should find at least 1 digit match, found {}",
count
);
for m in &matches {
assert!(m.end <= 12, "Match end should be within input bounds");
assert!(m.start <= m.end, "Match start should be <= end");
}
}
#[test]
#[cfg(feature = "regex")]
fn compile_regex_multiple_patterns() {
let jit = JitDfa::from_regex_patterns(&["foo", "bar", "baz"]).unwrap();
let mut matches = Vec::new();
let count = jit.scan(b"foo bar baz qux", &mut matches);
assert_eq!(count, 3);
}
#[test]
#[cfg(feature = "regex")]
fn compile_state_explosion_pattern() {
let jit = JitDfa::from_regex_patterns(&["a?a?a?aaa"]).unwrap();
let mut matches = Vec::new();
let count = jit.scan(b"aaa aaaa aaaaa aaaaaa", &mut matches);
assert!(count >= 1, "Should find at least 1 match, found {}", count);
}
#[test]
#[cfg(not(feature = "regex"))]
fn compile_state_explosion_pattern() {
let patterns: Vec<Vec<u8>> = (0..100)
.map(|i| vec![b'a'; i + 1]) .collect();
let pattern_refs: Vec<&[u8]> = patterns.iter().map(|p| p.as_slice()).collect();
let jit = JitDfa::from_patterns(&pattern_refs).unwrap();
let mut matches = vec![Match::from_parts(0, 0, 0); 100];
let count = jit.scan(&vec![b'a'; 50], &mut matches);
assert!(
count >= 50,
"Should find at least 50 matches, found {}",
count
);
}
#[test]
fn compile_all_256_byte_values_character_class() {
let mut table = TransitionTable::new(2, 256).unwrap();
for byte in 0u8..=255u8 {
table.set_transition(0, byte, 1);
}
table.add_accept(1, 0);
table.set_pattern_length(0, 1);
let jit = JitDfa::compile(&table).unwrap();
let input: Vec<u8> = (0u8..=255u8).collect();
assert_eq!(jit.scan_count(&input), 256);
}
#[test]
fn compile_huge_dfa_falls_back_to_interpreted() {
let mut table = reset_table(5000);
table.set_transition(0, b'x', 1);
table.add_accept(1, 0);
table.set_pattern_length(0, 1);
let jit = JitDfa::compile(&table).unwrap();
assert_eq!(jit.state_count(), 5000);
assert_eq!(jit.scan_count(b"x"), 1);
}
#[test]
fn compile_dfa_exceeds_max_states_rejected() {
let result = TransitionTable::new(65_537, 256);
assert!(
result.is_err(),
"65537 states must be rejected by TransitionTable::new"
);
}
#[test]
fn jit_parity_empty_input() {
let table = {
let mut t = reset_table(3);
t.set_transition(0, b'a', 1);
t.set_transition(1, b'b', 2);
t.add_accept(2, 0);
t.set_pattern_length(0, 2);
t
};
let jit = JitDfa::compile(&table).unwrap();
let count = jit.scan_count(b"");
let interp_matches = scan_interpreted(&table, b"");
assert_eq!(count, interp_matches.len());
}
#[test]
fn jit_parity_single_byte_input() {
let table = {
let mut t = reset_table(2);
t.set_transition(0, b'x', 1);
t.add_accept(1, 0);
t.set_pattern_length(0, 1);
t
};
let jit = JitDfa::compile(&table).unwrap();
for byte in 0u8..=255u8 {
let input = [byte];
let count = jit.scan_count(&input);
let interp_matches = scan_interpreted(&table, &input);
assert_eq!(count, interp_matches.len(), "Mismatch for byte {}", byte);
}
}
#[test]
fn jit_parity_1mb_input() {
let mut table = reset_table(2);
for byte in 0u8..=255u8 {
table.set_transition(0, byte, 1);
}
table.add_accept(1, 0);
table.set_pattern_length(0, 1);
let jit = JitDfa::compile(&table).unwrap();
let input = vec![b'a'; 1_000_000];
let jit_count = jit.scan_count(&input);
let interp_matches = scan_interpreted(&table, &input);
assert_eq!(
jit_count,
interp_matches.len(),
"JIT and interpreted should match"
);
assert_eq!(jit_count, 1_000_000, "Every byte should produce a match");
}
#[test]
fn jit_parity_pattern_at_exact_start() {
let mut table = reset_table(4);
table.set_transition(0, b'h', 1);
table.set_transition(1, b'e', 2);
table.set_transition(2, b'y', 3);
table.add_accept(3, 0);
table.set_pattern_length(0, 3);
let jit = JitDfa::compile(&table).unwrap();
let jit_matches = {
let mut m = vec![Match::from_parts(0, 0, 0); 10];
let c = jit.scan(b"hey there", &mut m);
m.truncate(c);
m
};
let interp_matches = scan_interpreted(&table, b"hey there");
assert_eq!(jit_matches.len(), interp_matches.len());
assert_eq!(jit_matches[0].start, 0);
assert_eq!(jit_matches[0].end, 3);
}
#[test]
fn jit_parity_pattern_at_exact_end() {
let mut table = reset_table(4);
table.set_transition(0, b'e', 1);
table.set_transition(1, b'n', 2);
table.set_transition(2, b'd', 3);
table.add_accept(3, 0);
table.set_pattern_length(0, 3);
let jit = JitDfa::compile(&table).unwrap();
let jit_matches = {
let mut m = vec![Match::from_parts(0, 0, 0); 10];
let c = jit.scan(b"the end", &mut m);
m.truncate(c);
m
};
let interp_matches = scan_interpreted(&table, b"the end");
assert_eq!(jit_matches.len(), interp_matches.len());
assert_eq!(jit_matches[0].start, 4);
assert_eq!(jit_matches[0].end, 7);
}
#[test]
fn jit_parity_overlapping_patterns() {
let mut table = reset_table(3);
table.set_transition(0, b'a', 1);
table.set_transition(1, b'a', 2);
table.set_transition(2, b'a', 2); table.add_accept(2, 0);
table.set_pattern_length(0, 2);
let jit = JitDfa::compile(&table).unwrap();
let jit_matches = {
let mut m = vec![Match::from_parts(0, 0, 0); 10];
let c = jit.scan(b"aaaa", &mut m);
m.truncate(c);
m
};
let interp_matches = scan_interpreted(&table, b"aaaa");
assert_eq!(
jit_matches.len(),
interp_matches.len(),
"JIT found {} matches, interpreted found {}",
jit_matches.len(),
interp_matches.len()
);
assert_eq!(
jit_matches.len(),
2,
"Should find exactly 2 matches in 'aaaa'"
);
}
#[test]
fn jit_parity_all_bytes_random_input() {
let mut table = reset_table(2);
for byte in 0u8..=255u8 {
table.set_transition(0, byte, 1);
}
table.add_accept(1, 0);
table.set_pattern_length(0, 1);
let jit = JitDfa::compile(&table).unwrap();
let input: Vec<u8> = (0u8..=255u8)
.chain((0u8..=255u8).rev())
.chain(0u8..=127u8)
.collect();
let jit_count = jit.scan_count(&input);
let interp_matches = scan_interpreted(&table, &input);
assert_eq!(
jit_count,
interp_matches.len(),
"JIT count {} != interpreted count {}",
jit_count,
interp_matches.len()
);
assert_eq!(jit_count, input.len(), "Every byte should produce a match");
}
#[test]
#[cfg(feature = "regex")]
fn regex_with_start_anchor() {
let jit = JitDfa::from_regex_patterns(&["^hello"]).unwrap();
let mut matches = Vec::new();
let count = jit.scan(b"hello world", &mut matches);
assert!(count >= 1, "Should match at start");
let _count2 = jit.scan_count(b"say hello");
}
#[test]
#[cfg(feature = "regex")]
fn regex_with_end_anchor() {
let jit = JitDfa::from_regex_patterns(&["world$"]).unwrap();
let mut matches = Vec::new();
let count = jit.scan(b"hello world", &mut matches);
assert!(count >= 1, "Should match at end");
}
#[test]
#[cfg(feature = "regex")]
fn regex_with_start_and_end_anchor() {
let jit = JitDfa::from_regex_patterns(&["^hello$"]).unwrap();
let mut matches = Vec::new();
let count = jit.scan(b"hello", &mut matches);
assert!(count >= 1, "Should match exact string");
}
#[test]
fn pattern_longer_than_input() {
let jit = JitDfa::from_patterns(&[b"longpattern"]).unwrap();
let mut matches = Vec::new();
let count = jit.scan(b"short", &mut matches);
assert_eq!(count, 0, "Long pattern should not match short input");
}
#[test]
fn empty_pattern_handling() {
let jit = JitDfa::from_patterns(&[b"", b"x"]).unwrap();
let mut matches = vec![Match::from_parts(0, 0, 0); 10];
let count = jit.scan(b"x", &mut matches);
assert_eq!(count, 1, "Should find the 'x' pattern");
assert_eq!(
matches[0].pattern_id, 1,
"'x' should have pattern_id 1 (after empty pattern)"
);
}
#[test]
fn pattern_exactly_matching_input_length() {
let jit = JitDfa::from_patterns(&[b"exact"]).unwrap();
let mut matches = vec![Match::from_parts(0, 0, 0); 10];
let count = jit.scan(b"exact", &mut matches);
assert_eq!(count, 1);
assert_eq!(matches[0].start, 0);
assert_eq!(matches[0].end, 5);
}
#[test]
fn unicode_multibyte_boundary() {
let mut table = reset_table(3);
table.set_transition(0, 0xC3, 1);
table.set_transition(1, 0xA9, 2);
table.add_accept(2, 0);
table.set_pattern_length(0, 2);
let jit = JitDfa::compile(&table).unwrap();
let input = "café".as_bytes();
let mut matches = vec![Match::from_parts(0, 0, 0); 10];
let count = jit.scan(input, &mut matches);
assert_eq!(count, 1, "Should match UTF-8 sequence");
}
#[test]
fn null_bytes_in_input() {
let mut table = reset_table(3);
table.set_transition(0, 0x00, 1);
table.set_transition(1, 0x00, 2);
table.add_accept(2, 0);
table.set_pattern_length(0, 2);
let jit = JitDfa::compile(&table).unwrap();
let input = vec![0x00, 0x00, 0x00, 0x00];
let mut matches = vec![Match::from_parts(0, 0, 0); 10];
let count = jit.scan(&input, &mut matches);
assert_eq!(
count, 2,
"Should match two non-overlapping pairs of null bytes"
);
}
#[test]
fn high_bit_bytes() {
let mut table = reset_table(2);
for byte in 128u8..=255u8 {
table.set_transition(0, byte, 1);
}
table.add_accept(1, 0);
table.set_pattern_length(0, 1);
let jit = JitDfa::compile(&table).unwrap();
let input: Vec<u8> = (128u8..=255u8).collect();
assert_eq!(jit.scan_count(&input), 128);
}
#[test]
fn compute_ranges_digit_class() {
let mut table = reset_table(2);
for byte in b'0'..=b'9' {
table.set_transition(0, byte, 1);
}
let ranges = table.compute_ranges();
let state0_ranges = &ranges[0];
let digit_range = state0_ranges
.iter()
.find(|(_, hi, target)| *hi == b'9' && *target == 1);
assert!(digit_range.is_some(), "Should have range for digits");
let (lo, hi, target) = digit_range.unwrap();
assert_eq!(*lo, b'0');
assert_eq!(*hi, b'9');
assert_eq!(*target, 1);
}
#[test]
fn compute_ranges_alpha_class() {
let mut table = reset_table(2);
for byte in b'a'..=b'z' {
table.set_transition(0, byte, 1);
}
for byte in b'A'..=b'Z' {
table.set_transition(0, byte, 1);
}
let ranges = table.compute_ranges();
let state0_ranges = &ranges[0];
let alpha_ranges: Vec<_> = state0_ranges
.iter()
.filter(|(_, _, target)| *target == 1)
.collect();
assert_eq!(alpha_ranges.len(), 2);
assert_eq!(alpha_ranges[0], &(b'A', b'Z', 1));
assert_eq!(alpha_ranges[1], &(b'a', b'z', 1));
}
#[test]
fn compute_ranges_negated_newline() {
let mut table = reset_table(2);
for byte in 0u8..=255u8 {
if byte != b'\n' {
table.set_transition(0, byte, 1);
}
}
table.add_accept(1, 0);
let ranges = table.compute_ranges();
let state0_ranges = &ranges[0];
let nonzero_targets: Vec<_> = state0_ranges
.iter()
.filter(|(_, _, target)| *target == 1)
.collect();
assert_eq!(nonzero_targets.len(), 2);
assert_eq!(nonzero_targets[0], &(0x00, 0x09, 1));
assert_eq!(nonzero_targets[1], &(0x0B, 0xFF, 1));
}
#[test]
fn compute_ranges_single_byte_ranges() {
let mut table = TransitionTable::new(257, 256).unwrap();
for (i, byte) in (0u8..=255u8).enumerate() {
table.set_transition(0, byte, (i + 1) as u32);
}
let ranges = table.compute_ranges();
let state0_ranges = &ranges[0];
assert_eq!(state0_ranges.len(), 256);
}
#[test]
fn compute_ranges_single_range_all_same() {
let mut table = reset_table(2);
for byte in 0u8..=255u8 {
table.set_transition(0, byte, 1);
}
let ranges = table.compute_ranges();
let state0_ranges = &ranges[0];
assert_eq!(state0_ranges.len(), 1);
assert_eq!(state0_ranges[0], (0, 255, 1));
}
#[test]
fn compute_ranges_alternating_targets() {
let mut table = reset_table(3);
for byte in 0u8..=255u8 {
if byte % 2 == 0 {
table.set_transition(0, byte, 1);
} else {
table.set_transition(0, byte, 2);
}
}
let ranges = table.compute_ranges();
let state0_ranges = &ranges[0];
assert_eq!(state0_ranges.len(), 256);
}
#[test]
fn compute_ranges_empty_class_count() {
let result = TransitionTable::new(2, 0);
assert!(result.is_err());
}
#[test]
fn scan_first_returns_only_first() {
let mut table = reset_table(3);
table.set_transition(0, b'a', 1);
table.set_transition(1, b'a', 2);
table.set_transition(2, b'a', 2);
table.add_accept(1, 0);
table.add_accept(2, 1);
table.set_pattern_length(0, 1);
table.set_pattern_length(1, 2);
let jit = JitDfa::compile(&table).unwrap();
let first = jit.scan_first(b"aaa");
assert!(first.is_some());
assert_eq!(first.unwrap().start, 0);
}
#[test]
fn has_match_true_and_false() {
let mut table = reset_table(2);
table.set_transition(0, b'x', 1);
table.add_accept(1, 0);
table.set_pattern_length(0, 1);
let jit = JitDfa::compile(&table).unwrap();
assert!(jit.has_match(b"xxx"));
assert!(!jit.has_match(b"yyy"));
assert!(!jit.has_match(b""));
}
#[test]
fn minimize_preserves_multi_pattern_behavior() {
let mut table = reset_table(4);
table.set_transition(0, b'a', 1);
table.set_transition(0, b'b', 2);
table.set_transition(1, b'a', 1);
table.set_transition(2, b'b', 2);
table.add_accept(1, 0);
table.add_accept(2, 1);
table.set_pattern_length(0, 1);
table.set_pattern_length(1, 1);
let original_jit = JitDfa::compile(&table).unwrap();
let mut orig_matches = Vec::new();
original_jit.scan(b"ab", &mut orig_matches);
if let Some(minimized) = table.minimize() {
let min_jit = JitDfa::compile(&minimized).unwrap();
let mut min_matches = Vec::new();
min_jit.scan(b"ab", &mut min_matches);
assert_eq!(orig_matches.len(), min_matches.len());
}
}
#[test]
fn serialization_preserves_jit_behavior() {
let mut table = reset_table(4);
table.set_transition(0, b'f', 1);
table.set_transition(1, b'o', 2);
table.set_transition(2, b'o', 3);
table.add_accept(3, 0);
table.set_pattern_length(0, 3);
let bytes = table.to_bytes();
let restored = TransitionTable::from_bytes(&bytes).unwrap();
let orig_jit = JitDfa::compile(&table).unwrap();
let restored_jit = JitDfa::compile(&restored).unwrap();
let input = b"foobar";
let mut orig_matches = vec![Match::from_parts(0, 0, 0); 10];
let mut restored_matches = vec![Match::from_parts(0, 0, 0); 10];
let orig_count = orig_jit.scan(input, &mut orig_matches);
let restored_count = restored_jit.scan(input, &mut restored_matches);
assert_eq!(orig_count, restored_count);
assert_eq!(orig_matches[0].start, restored_matches[0].start);
assert_eq!(orig_matches[0].end, restored_matches[0].end);
}
#[test]
fn transition_density_computed_correctly() {
let mut table = reset_table(3);
table.set_transition(0, b'a', 1);
table.set_transition(0, b'b', 2);
table.set_transition(1, b'c', 1);
assert_eq!(table.transition_density(0), 3); assert_eq!(table.transition_density(1), 2); assert_eq!(table.transition_density(2), 1); }
#[test]
fn estimated_code_size_reasonable() {
let small_table = reset_table(3);
let medium_table = reset_table(100);
let large_table = reset_table(1000);
let small_size = small_table.estimated_code_size();
let medium_size = medium_table.estimated_code_size();
let large_size = large_table.estimated_code_size();
assert!(small_size < medium_size);
assert!(medium_size < large_size);
}
#[test]
fn jit_eligibility_boundary() {
let eligible = TransitionTable::new(4096, 256).unwrap();
let not_eligible = TransitionTable::new(4097, 256).unwrap();
assert!(eligible.is_jit_eligible());
assert!(!not_eligible.is_jit_eligible());
}
#[test]
fn pattern_length_zero_variable_width() {
let mut table = reset_table(2);
table.set_transition(0, b'x', 1);
table.add_accept(1, 0);
let jit = JitDfa::compile(&table).unwrap();
let mut matches = vec![Match::from_parts(0, 0, 0); 10];
let count = jit.scan(b"x", &mut matches);
assert_eq!(count, 1);
assert_eq!(matches[0].start, matches[0].end);
}