use crate::dfa::{CharClass, DfaStateId, LazyDfa};
use crate::error::{Error, ErrorKind, Result};
use dynasmrt::{AssemblyOffset, ExecutableBuffer};
/// Compile-time switch for the ARM64 DFA JIT.
///
/// `JitCompiler::compile_dfa` checks this first and returns an
/// `ErrorKind::Jit` error ("ARM64 DFA JIT temporarily disabled") when it is
/// `false`.
const ARM64_DFA_JIT_ENABLED: bool = true;
/// A regex compiled to native code, plus the pattern flags the Rust-side
/// wrapper needs in order to pick an entry point and to validate assertions
/// at the end of a reported match.
pub struct CompiledRegex {
/// Buffer holding the generated code; entry points inside it are called as
/// `unsafe extern "C" fn(*const u8, usize) -> i64`.
code: ExecutableBuffer,
/// Default entry offset into `code`.
entry_point: AssemblyOffset,
/// Alternate entry offset, used when the pattern has word boundaries and the
/// byte before the search start is a word character. When `None`, the default
/// entry point is used instead.
entry_point_word: Option<AssemblyOffset>,
/// Whether the pattern contains word-boundary assertions. Gates the
/// word-entry selection, the previous-character lookup in `find_at`, and the
/// boundary validation of match ends.
pub(crate) has_word_boundary: bool,
/// Set when some match state requires a word boundary; a match whose end is
/// not at a word boundary is then rejected.
match_needs_word_boundary: bool,
/// Set when some match state requires a non-word-boundary; a match whose end
/// is at a word boundary is then rejected.
match_needs_not_word_boundary: bool,
/// Whether the pattern contains anchors. Gates the end-of-text and
/// end-of-line validation of match ends.
pub(crate) has_anchors: bool,
/// Whether the pattern has a start anchor; `find` and `find_at` then only
/// accept start positions that satisfy it.
pub(crate) has_start_anchor: bool,
// NOTE(review): no method visible in this section reads this field, which is
// presumably why the lint is silenced — confirm it is still needed.
#[allow(dead_code)]
pub(crate) has_end_anchor: bool,
/// Whether anchors are multiline: a start anchor is then also satisfied
/// immediately after a `\n` byte.
pub(crate) has_multiline_anchors: bool,
/// Set when some match state requires the match to end at the end of input.
pub(crate) match_needs_end_of_text: bool,
/// Set when some match state requires the match to end at the end of input
/// or immediately before a `\n` byte.
pub(crate) match_needs_end_of_line: bool,
}
impl CompiledRegex {
    /// Invokes the generated matcher on `input` and post-validates the result.
    ///
    /// `prev_class` is the class of the byte that logically precedes `input`.
    /// It selects the entry point and is the left-hand context when the match
    /// starts at offset 0. The generated function returns a negative value
    /// for "no match"; otherwise the high 32 bits hold the start offset and
    /// the low 32 bits the end offset, both relative to `input`.
    fn execute_with_class(&self, input: &[u8], prev_class: CharClass) -> Option<(usize, usize)> {
        type MatchFn = unsafe extern "C" fn(*const u8, usize) -> i64;

        let preceded_by_word = self.has_word_boundary && prev_class == CharClass::Word;
        let entry = match self.entry_point_word {
            Some(word_entry) if preceded_by_word => word_entry,
            _ => self.entry_point,
        };

        // SAFETY (assumed, cannot be verified from this file): `entry` is
        // expected to be an offset into `self.code` at which the code
        // generator emitted a function with the `MatchFn` signature.
        let func: MatchFn = unsafe { std::mem::transmute(self.code.ptr(entry)) };
        // SAFETY (assumed): the generated code is expected to read at most
        // `input.len()` bytes starting at `input.as_ptr()`.
        let raw = unsafe { func(input.as_ptr(), input.len()) };
        if raw < 0 {
            return None;
        }

        let packed = raw as u64;
        let start_pos = (packed >> 32) as usize;
        let end_pos = (packed & 0xFFFF_FFFF) as usize;
        if self.validate_end_assertions(input, start_pos, end_pos, prev_class) {
            Some((start_pos, end_pos))
        } else {
            None
        }
    }

    /// Checks the assertions that must hold at the end of a reported match.
    ///
    /// Returns `false` when a required word-boundary, non-word-boundary,
    /// end-of-text or end-of-line condition is not satisfied at `end_pos`.
    fn validate_end_assertions(
        &self,
        input: &[u8],
        start_pos: usize,
        end_pos: usize,
        prev_class: CharClass,
    ) -> bool {
        let checks_boundary = self.match_needs_word_boundary || self.match_needs_not_word_boundary;
        if self.has_word_boundary && checks_boundary {
            // Class of whatever sits immediately before the match start.
            let class_before_match = match start_pos {
                0 => prev_class,
                _ => CharClass::from_byte(input[start_pos - 1]),
            };
            // Past the end of input counts as a non-word position.
            let class_after = if end_pos < input.len() {
                CharClass::from_byte(input[end_pos])
            } else {
                CharClass::NonWord
            };
            // For an empty match, the left neighbour of `end_pos` is the
            // byte before the match; otherwise it is the last matched byte.
            let class_before_end = if end_pos == start_pos {
                class_before_match
            } else {
                CharClass::from_byte(input[end_pos - 1])
            };

            let is_at_boundary = class_before_end != class_after;
            let violated = if is_at_boundary {
                self.match_needs_not_word_boundary
            } else {
                self.match_needs_word_boundary
            };
            if violated {
                return false;
            }
        }

        if !self.has_anchors {
            return true;
        }
        let at_end_of_text = end_pos == input.len();
        if self.match_needs_end_of_text && !at_end_of_text {
            return false;
        }
        !self.match_needs_end_of_line || at_end_of_text || input.get(end_pos) == Some(&b'\n')
    }

    /// Runs the matcher over the whole of `input`, treating the position
    /// before it as a non-word character.
    pub fn execute(&self, input: &[u8]) -> Option<(usize, usize)> {
        self.execute_with_class(input, CharClass::NonWord)
    }

    /// Returns `true` when [`find`](Self::find) reports a match in `input`.
    pub fn is_match(&self, input: &[u8]) -> bool {
        self.find(input).is_some()
    }

    /// Returns `true` when the match reported by [`execute`](Self::execute)
    /// spans exactly `0..input.len()`.
    pub fn is_full_match(&self, input: &[u8]) -> bool {
        self.execute(input) == Some((0, input.len()))
    }

    /// Finds a match in `input`, returning `(start, end)` byte offsets.
    ///
    /// Without a start anchor this is the same as `execute`. With a
    /// non-multiline start anchor only offset 0 is tried. With a multiline
    /// start anchor, offset 0 and then every offset just after a `\n` are
    /// tried in order, and the first success is returned.
    pub fn find(&self, input: &[u8]) -> Option<(usize, usize)> {
        if !self.has_start_anchor {
            return self.execute(input);
        }
        if !self.has_multiline_anchors {
            return self.find_at(input, 0);
        }

        let after_newlines = input
            .iter()
            .enumerate()
            .filter(|&(_, &byte)| byte == b'\n')
            .map(|(newline_idx, _)| newline_idx + 1);
        std::iter::once(0)
            .chain(after_newlines)
            .find_map(|line_start| self.find_at(input, line_start))
    }

    /// Searches `input[start_pos..]` and reports offsets relative to `input`.
    ///
    /// Returns `None` when `start_pos` is past the end of `input`, or when
    /// the pattern has a start anchor that `start_pos` does not satisfy.
    pub fn find_at(&self, input: &[u8], start_pos: usize) -> Option<(usize, usize)> {
        let haystack = input.get(start_pos..)?;
        // Byte immediately before the search window, if there is one.
        let preceding = start_pos.checked_sub(1).map(|idx| input[idx]);

        if self.has_start_anchor {
            let after_newline = self.has_multiline_anchors && preceding == Some(b'\n');
            if start_pos != 0 && !after_newline {
                return None;
            }
        }

        let prev_class = match preceding {
            Some(byte) if self.has_word_boundary => CharClass::from_byte(byte),
            _ => CharClass::NonWord,
        };

        self.execute_with_class(haystack, prev_class)
            .map(|(rel_start, rel_end)| (start_pos + rel_start, start_pos + rel_end))
    }
}
/// Stateless driver that turns a `LazyDfa` into a `CompiledRegex` via
/// `compile_dfa`.
pub struct JitCompiler;
impl JitCompiler {
    /// Largest number of DFA states the ARM64 backend accepts.
    const MAX_ARM64_DFA_STATES: usize = 64;

    /// Creates a compiler. The type carries no state.
    pub fn new() -> Self {
        Self
    }

    /// Materializes `dfa` and compiles it to native ARM64 code.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorKind::Jit` error when the JIT is disabled through
    /// `ARM64_DFA_JIT_ENABLED` or when the DFA has more than
    /// `MAX_ARM64_DFA_STATES` reachable states. Errors from
    /// `crate::jit::aarch64::compile_states` are propagated unchanged.
    pub fn compile_dfa(self, dfa: &mut LazyDfa) -> Result<CompiledRegex> {
        if !ARM64_DFA_JIT_ENABLED {
            return Err(Error::new(
                ErrorKind::Jit("ARM64 DFA JIT temporarily disabled".to_string()),
                "",
            ));
        }

        // The state limit is enforced while materializing, so an oversized
        // DFA is rejected before it has been fully explored.
        let materialized = self.materialize_dfa(dfa)?;
        debug_assert!(materialized.states.len() <= Self::MAX_ARM64_DFA_STATES);

        let (code, entry_point, entry_point_word) =
            crate::jit::aarch64::compile_states(&materialized)?;

        // Aggregate the end-of-match requirements over all match states; the
        // compiled regex validates them after the native code reports a match.
        let mut match_needs_word_boundary = false;
        let mut match_needs_not_word_boundary = false;
        let mut match_needs_end_of_text = false;
        let mut match_needs_end_of_line = false;
        for state in materialized.states.iter().filter(|state| state.is_match) {
            match_needs_word_boundary |= state.needs_word_boundary;
            match_needs_not_word_boundary |= state.needs_not_word_boundary;
            match_needs_end_of_text |= state.needs_end_of_text;
            match_needs_end_of_line |= state.needs_end_of_line;
        }

        Ok(CompiledRegex {
            code,
            entry_point,
            entry_point_word,
            has_word_boundary: materialized.has_word_boundary,
            match_needs_word_boundary,
            match_needs_not_word_boundary,
            has_anchors: materialized.has_anchors,
            has_start_anchor: materialized.has_start_anchor,
            has_end_anchor: materialized.has_end_anchor,
            has_multiline_anchors: materialized.has_multiline_anchors,
            match_needs_end_of_text,
            match_needs_end_of_line,
        })
    }

    /// Explores every state reachable from the start state(s) of `dfa` and
    /// records its transitions and match requirements, sorted by state id.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorKind::Jit` error as soon as more than
    /// `MAX_ARM64_DFA_STATES` distinct states have been discovered. Stopping
    /// there keeps a DFA that is too large anyway from being fully expanded.
    fn materialize_dfa(&self, dfa: &mut LazyDfa) -> Result<MaterializedDfa> {
        let has_word_boundary = dfa.has_word_boundary();
        let has_anchors = dfa.has_anchors();
        let has_start_anchor = dfa.has_start_anchor();
        let has_end_anchor = dfa.has_end_anchor();
        let has_multiline_anchors = dfa.has_multiline_anchors();

        let start_nonword = dfa.get_start_state_for_class(CharClass::NonWord);
        // The word-context start state only exists for word-boundary patterns.
        let start_word = if has_word_boundary {
            Some(dfa.get_start_state_for_class(CharClass::Word))
        } else {
            None
        };

        let mut materialized = MaterializedDfa {
            states: Vec::new(),
            start: start_nonword,
            start_word,
            has_word_boundary,
            has_anchors,
            has_start_anchor,
            has_end_anchor,
            has_multiline_anchors,
        };

        let mut pending = vec![start_nonword];
        let mut visited = std::collections::HashSet::new();
        visited.insert(start_nonword);
        if let Some(word_start) = start_word {
            if visited.insert(word_start) {
                pending.push(word_start);
            }
        }

        while let Some(state_id) = pending.pop() {
            let transitions = dfa.compute_all_transitions(state_id);
            for next_state in transitions.iter().flatten().copied() {
                if !visited.insert(next_state) {
                    continue;
                }
                // Every discovered state ends up in `states`, so once the
                // limit is exceeded the result can only be rejected.
                if visited.len() > Self::MAX_ARM64_DFA_STATES {
                    return Err(Error::new(
                        ErrorKind::Jit(format!(
                            "DFA too large for ARM64 JIT (at least {} states, max {})",
                            visited.len(),
                            Self::MAX_ARM64_DFA_STATES
                        )),
                        "",
                    ));
                }
                pending.push(next_state);
            }

            let is_match = dfa.is_match(state_id);
            let (needs_word_boundary, needs_not_word_boundary) =
                dfa.get_state_boundary_requirements(state_id);
            let (needs_end_of_text, needs_end_of_line) =
                dfa.get_state_anchor_requirements(state_id);
            materialized.states.push(MaterializedState {
                id: state_id,
                transitions,
                is_match,
                needs_word_boundary,
                needs_not_word_boundary,
                needs_end_of_text,
                needs_end_of_line,
            });
        }

        materialized.states.sort_by_key(|state| state.id);
        Ok(materialized)
    }
}
impl Default for JitCompiler {
fn default() -> Self {
Self::new()
}
}
/// A fully explored DFA: every state reachable from the start state(s),
/// together with the pattern-level flags copied from the source `LazyDfa`.
pub struct MaterializedDfa {
/// All reachable states; `JitCompiler::materialize_dfa` sorts them by `id`.
pub states: Vec<MaterializedState>,
/// Start state obtained for `CharClass::NonWord`.
pub start: DfaStateId,
/// Start state obtained for `CharClass::Word`; `Some` only when the pattern
/// has word boundaries.
pub start_word: Option<DfaStateId>,
/// Copied from `LazyDfa::has_word_boundary`.
pub has_word_boundary: bool,
/// Copied from `LazyDfa::has_anchors`.
pub has_anchors: bool,
/// Copied from `LazyDfa::has_start_anchor`.
pub has_start_anchor: bool,
/// Copied from `LazyDfa::has_end_anchor`.
pub has_end_anchor: bool,
/// Copied from `LazyDfa::has_multiline_anchors`.
pub has_multiline_anchors: bool,
}
/// One DFA state with its complete byte-indexed transition table and the
/// assertions attached to it.
#[derive(Debug, Clone)]
pub struct MaterializedState {
/// Identifier of this state in the source DFA.
pub id: DfaStateId,
/// Target state for each input byte value (index = byte); `None` means the
/// state has no transition on that byte.
pub transitions: [Option<DfaStateId>; 256],
/// Whether the source DFA reports this state as a match state.
pub is_match: bool,
/// First element of `LazyDfa::get_state_boundary_requirements` for this state.
pub needs_word_boundary: bool,
/// Second element of `LazyDfa::get_state_boundary_requirements` for this state.
pub needs_not_word_boundary: bool,
/// First element of `LazyDfa::get_state_anchor_requirements` for this state.
pub needs_end_of_text: bool,
/// Second element of `LazyDfa::get_state_anchor_requirements` for this state.
pub needs_end_of_line: bool,
}
impl MaterializedState {
    /// Number of byte values (0..=255) for which this state has a transition.
    pub fn transition_density(&self) -> usize {
        self.transitions.iter().flatten().count()
    }

    /// Returns `true` when the state has more than 10 populated transitions.
    pub fn should_use_jump_table(&self) -> bool {
        let populated = self.transition_density();
        populated > 10
    }

    /// Groups the transition table into maximal runs of consecutive bytes
    /// that share a target.
    ///
    /// Each entry is `(first_byte, last_byte, target)` with both bounds
    /// inclusive; entries are in ascending byte order. Bytes without a
    /// transition belong to no entry and split runs on either side of them.
    pub fn transition_ranges(&self) -> Vec<(u8, u8, DfaStateId)> {
        let mut ranges: Vec<(u8, u8, DfaStateId)> = Vec::new();

        for (byte, slot) in (0..=u8::MAX).zip(self.transitions.iter().copied()) {
            let target = match slot {
                Some(target) => target,
                None => continue,
            };

            // The current byte extends the previous run only if that run ends
            // on the byte just before it and leads to the same state.
            let extends_last = match ranges.last() {
                Some(&(_, last_end, last_target)) => {
                    last_target == target && last_end.checked_add(1) == Some(byte)
                }
                None => false,
            };

            if extends_last {
                if let Some(last) = ranges.last_mut() {
                    last.1 = byte;
                }
            } else {
                ranges.push((byte, byte, target));
            }
        }

        ranges
    }
}