use crate::insn::{MAX_BYTE_SET_LENGTH, MAX_CHAR_SET_LENGTH};
use crate::ir::*;
use crate::types::BracketContents;
use crate::unicode;
#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, vec::Vec};
/// Largest minimum repetition count that `unroll_loops` will expand into a
/// concatenation; loops with a larger `quant.min` are left as loops.
const LOOP_UNROLL_THRESHOLD: usize = 5;
/// What a pass function decided to do with the node it was handed.
///
/// The pass runner applies the decision and records whether the tree changed.
pub enum PassAction {
    /// Nothing was changed; the node stays as it is.
    Keep,
    /// The pass already mutated the node in place.
    Modified,
    /// The node is replaced by `Node::Empty`.
    Remove,
    /// The node is replaced by the carried node.
    Replace(Node),
}
/// State for applying one node-rewriting function over a regex IR tree.
///
/// `F` is deliberately left unbounded on the type; the `impl` block that drives
/// the pass requires `F: FnMut(&mut Node, &Walk) -> PassAction`.
#[derive(Debug)]
struct Pass<'a, F> {
    /// Callback invoked for each visited node; its `PassAction` says how to update the node.
    func: &'a mut F,
    /// Set by the pass runner when `func` reports anything other than `PassAction::Keep`.
    changed: bool,
    /// Unicode-mode flag handed to the tree walker.
    unicode: bool,
}
impl<'a, F> Pass<'a, F>
where
    F: FnMut(&mut Node, &Walk) -> PassAction,
{
    /// Create a pass that has not yet recorded any change.
    fn new(func: &'a mut F, unicode: bool) -> Self {
        Pass {
            func,
            changed: false,
            unicode,
        }
    }

    /// Traverse the tree rooted at `start` once, applying `func` to every node.
    ///
    /// `Remove` turns the node into `Node::Empty` and `Replace` installs the new node.
    /// Any action other than `Keep` sets `self.changed`.
    fn run_postorder(&mut self, start: &mut Node) {
        // NOTE(review): the leading `true` presumably selects post-order traversal
        // (per this method's name) — confirm against `walk_mut`.
        walk_mut(
            true,
            self.unicode,
            start,
            &mut |n: &mut Node, walk: &mut Walk| match (self.func)(n, walk) {
                PassAction::Keep => {}
                PassAction::Modified => {
                    self.changed = true;
                }
                PassAction::Remove => {
                    *n = Node::Empty;
                    self.changed = true;
                }
                PassAction::Replace(newnode) => {
                    *n = newnode;
                    self.changed = true;
                }
            },
        )
    }

    /// Re-run the traversal until a full pass over the tree makes no change.
    ///
    /// On return, `self.changed` is `true` iff *any* traversal modified the tree.
    /// Callers use this to decide whether further optimization rounds are needed.
    fn run_to_fixpoint(&mut self, n: &mut Node) {
        debug_assert!(!self.changed, "Pass has already been run");
        let mut modified_any = false;
        loop {
            self.changed = false;
            self.run_postorder(n);
            if !self.changed {
                break;
            }
            modified_any = true;
        }
        // The terminating traversal always reports "no change". Restore the cumulative
        // result; otherwise every caller would observe `false` unconditionally.
        self.changed = modified_any;
    }
}
/// Apply `func` to `r`'s IR tree until the pass reaches a fixpoint.
///
/// The walker runs in `r.flags.unicode` mode. Returns the `changed` flag left on
/// the `Pass` after `run_to_fixpoint`; `optimize` folds these flags together to
/// decide whether to run another round.
fn run_pass<F>(r: &mut Regex, func: &mut F) -> bool
where
    F: FnMut(&mut Node, &Walk) -> PassAction,
{
    let unicode_mode = r.flags.unicode;
    let mut pass = Pass::new(func, unicode_mode);
    let root = &mut r.node;
    pass.run_to_fixpoint(root);
    pass.changed
}
/// Drop nodes that match only the empty string and have no other effect.
///
/// - Empty `ByteSequence`s are removed.
/// - `Cat` children that are `Empty` are pruned. A `Cat` left with no children is
///   removed, and one left with a single child is replaced by that child.
/// - An `Alt` is removed only if both branches are empty.
/// - A `Loop` is removed if its body is empty, or if it has `max == Some(0)` and
///   encloses no capture groups.
/// - A positive lookaround with empty contents always succeeds and is removed.
///
/// Every other node kind is kept.
fn remove_empties(n: &mut Node, _w: &Walk) -> PassAction {
    let removable = match n {
        Node::Cat(children) => {
            let before = children.len();
            children.retain(|child| !child.is_empty());
            let after = children.len();
            return if after == before {
                PassAction::Keep
            } else if after == 0 {
                PassAction::Remove
            } else if after == 1 {
                PassAction::Replace(children.pop().unwrap())
            } else {
                PassAction::Modified
            };
        }
        Node::ByteSequence(bytes) => bytes.is_empty(),
        // Only removable if neither branch can contribute anything.
        Node::Alt(left, right) => left.is_empty() && right.is_empty(),
        Node::Loop {
            loopee,
            quant,
            enclosed_groups,
        } => {
            let encloses_no_groups = enclosed_groups.start == enclosed_groups.end;
            loopee.is_empty() || (quant.max == Some(0) && encloses_no_groups)
        }
        // A negative lookaround of an empty pattern always fails, so it must stay.
        Node::LookaroundAssertion {
            negate, contents, ..
        } => !*negate && contents.is_empty(),
        Node::Empty
        | Node::Goal
        | Node::Char { .. }
        | Node::ByteSet(..)
        | Node::CharSet(..)
        | Node::MatchAny
        | Node::MatchAnyExceptLineTerminator
        | Node::Anchor { .. }
        | Node::Loop1CharBody { .. }
        | Node::CaptureGroup { .. }
        | Node::WordBoundary { .. }
        | Node::BackRef { .. }
        | Node::Bracket { .. } => false,
    };
    if removable {
        PassAction::Remove
    } else {
        PassAction::Keep
    }
}
/// Report whether `node` is, or has beneath it, a `CaptureGroup`.
///
/// Recursion descends through `Cat`, `Alt`, `Loop` and `LookaroundAssertion`.
/// Every other node kind, including `Loop1CharBody`, is treated as group-free.
fn contains_capture_groups(node: &Node) -> bool {
    match node {
        Node::CaptureGroup { .. } => true,
        Node::Cat(children) => {
            for child in children {
                if contains_capture_groups(child) {
                    return true;
                }
            }
            false
        }
        Node::Alt(left, right) => {
            if contains_capture_groups(left) {
                true
            } else {
                contains_capture_groups(right)
            }
        }
        Node::Loop { loopee: body, .. } => contains_capture_groups(body),
        Node::LookaroundAssertion { contents: body, .. } => contains_capture_groups(body),
        _ => false,
    }
}
fn propagate_early_fails(n: &mut Node, _w: &Walk) -> PassAction {
if contains_capture_groups(n) {
return PassAction::Keep;
}
match n {
Node::Cat(nodes) => {
if nodes.iter().any(|nn| nn.match_always_fails()) {
PassAction::Replace(Node::make_always_fails())
} else {
PassAction::Keep
}
}
Node::Alt(left, right) => {
let left_fails = left.match_always_fails();
let right_fails = right.match_always_fails();
match (left_fails, right_fails) {
(true, true) => PassAction::Replace(Node::make_always_fails()),
(false, false) => PassAction::Keep,
(true, false) | (false, true) => {
let mut new_node = Node::Empty;
core::mem::swap(
&mut new_node,
if left_fails { &mut *right } else { &mut *left },
);
PassAction::Replace(new_node)
}
}
}
Node::Loop {
loopee,
quant,
enclosed_groups,
} => {
if enclosed_groups.start < enclosed_groups.end {
return PassAction::Keep;
}
if quant.min > 0 && loopee.match_always_fails() {
PassAction::Replace(Node::make_always_fails())
} else {
PassAction::Keep
}
}
_ => PassAction::Keep,
}
}
/// Flatten and simplify concatenations.
///
/// An empty `Cat` is removed and a single-element `Cat` is replaced by its only
/// child. A `Cat` with any direct `Cat` children has those children spliced in
/// one level deep. Non-`Cat` nodes are kept.
fn decat(n: &mut Node, _w: &Walk) -> PassAction {
    let Node::Cat(nodes) = n else {
        return PassAction::Keep;
    };
    match nodes.len() {
        0 => PassAction::Remove,
        1 => PassAction::Replace(nodes.pop().unwrap()),
        _ if nodes.iter().any(|nn| nn.is_cat()) => {
            let mut flattened = Vec::new();
            for child in core::mem::take(nodes) {
                if let Node::Cat(mut grandchildren) = child {
                    flattened.append(&mut grandchildren);
                } else {
                    flattened.push(child);
                }
            }
            PassAction::Replace(Node::Cat(flattened))
        }
        _ => PassAction::Keep,
    }
}
/// Expand a case-insensitive `Char` into its case-fold equivalents.
///
/// Non-Unicode mode uses `unicode::unfold_uppercase_char`; Unicode mode uses
/// `unicode::unfold_char`. A char with a single fold becomes a case-sensitive
/// `Char`; otherwise it becomes a `CharSet` of every fold.
///
/// # Panics
/// Panics if the unfolding is empty or larger than the supported set size.
fn unfold_icase_chars(n: &mut Node, w: &Walk) -> PassAction {
    let c = match *n {
        Node::Char { c, icase: true } => c,
        _ => return PassAction::Keep,
    };
    let unfolded = if w.unicode {
        unicode::unfold_char(c)
    } else {
        unicode::unfold_uppercase_char(c)
    };
    debug_assert!(
        unfolded.contains(&c),
        "Char should always unfold to at least itself"
    );
    // NOTE(review): the upper bound is MAX_BYTE_SET_LENGTH even though a CharSet is
    // produced — possibly so an all-ASCII result can later become a ByteSet; confirm.
    match unfolded.len() {
        0 => panic!("Char should always unfold to at least itself"),
        1 => PassAction::Replace(Node::Char { c, icase: false }),
        2..=MAX_BYTE_SET_LENGTH => PassAction::Replace(Node::CharSet(unfolded)),
        _ => panic!("Unfolded to more characters than we believed possible"),
    }
}
/// Peel the mandatory iterations off small loops.
///
/// A group-free `Loop` with `1 <= quant.min <= LOOP_UNROLL_THRESHOLD` becomes a
/// `Cat` of `quant.min` copies of its body. If any optional iterations remain,
/// the `Cat` is followed by the loop itself, rewritten to `min = 0` with `max`
/// reduced by the old `min`. If the body cannot be duplicated, the node is left
/// untouched.
fn unroll_loops(n: &mut Node, _w: &Walk) -> PassAction {
    let Node::Loop {
        loopee,
        quant,
        enclosed_groups,
    } = n
    else {
        return PassAction::Keep;
    };
    // Loops that enclose capture groups are never unrolled.
    if enclosed_groups.start < enclosed_groups.end {
        return PassAction::Keep;
    }
    let min = quant.min;
    if min == 0 || min > LOOP_UNROLL_THRESHOLD {
        return PassAction::Keep;
    }
    let Some(mut pieces) = (0..min)
        .map(|_| loopee.try_duplicate(0))
        .collect::<Option<Vec<_>>>()
    else {
        return PassAction::Keep;
    };
    // The leftover loop only covers the optional iterations.
    quant.min = 0;
    quant.max = quant.max.map(|hi| hi - min);
    if quant.max != Some(0) {
        pieces.push(core::mem::replace(n, Node::Empty));
    }
    *n = Node::Cat(pieces);
    PassAction::Modified
}
/// Turn a `Loop` whose body matches exactly one character into a `Loop1CharBody`.
///
/// The quantifier is carried over unchanged.
///
/// # Panics
/// Panics if such a loop reports enclosed capture groups.
fn promote_1char_loops(n: &mut Node, _w: &Walk) -> PassAction {
    let Node::Loop {
        loopee,
        quant,
        enclosed_groups,
    } = n
    else {
        return PassAction::Keep;
    };
    if !loopee.matches_exactly_one_char() {
        return PassAction::Keep;
    }
    assert!(
        enclosed_groups.start >= enclosed_groups.end,
        "Should have no enclosed groups"
    );
    let body = core::mem::replace(loopee, Box::new(Node::Empty));
    let quant = *quant;
    *n = Node::Loop1CharBody {
        loopee: body,
        quant,
    };
    PassAction::Modified
}
/// Lower literal characters to byte-level nodes and merge adjacent byte runs.
///
/// Only compiled when the `utf16` feature is disabled.
///
/// - A case-sensitive `Char` whose code point is a valid `char` becomes a
///   `ByteSequence` holding its UTF-8 encoding. Code points that are not valid
///   `char`s (e.g. lone surrogates) are kept as-is.
/// - A `CharSet` whose members are all `<= 0x7F` becomes a `ByteSet`.
/// - Inside a `Cat`, each pair of neighbouring non-empty `ByteSequence`s is
///   merged. The combined bytes end up in the later node and the earlier node
///   is left empty, for `remove_empties` to drop later.
#[cfg(not(feature = "utf16"))]
fn form_literal_bytes(n: &mut Node, walk: &Walk) -> PassAction {
// Borrow the byte buffer of a `ByteSequence`, or `None` for any other node.
fn get_literal_bytes(n: &mut Node) -> Option<&mut Vec<u8>> {
match n {
Node::ByteSequence(v) => Some(v),
_ => None,
}
}
match n {
Node::Char { c, icase } if !*icase => {
if let Some(c) = char::from_u32(*c) {
let mut buff = [0; 4];
PassAction::Replace(Node::ByteSequence(
c.encode_utf8(&mut buff).as_bytes().to_vec(),
))
} else {
PassAction::Keep
}
}
Node::CharSet(chars) if chars.iter().all(|&c| c <= 0x7F) => {
PassAction::Replace(Node::ByteSet(chars.iter().map(|&c| c as u8).collect()))
}
Node::Cat(nodes) => {
let mut modified = false;
// Walk adjacent pairs (idx-1, idx). split_at_mut yields two disjoint &mut
// borrows, and merged bytes are carried forward into `curr`, so a run of
// several sequences accumulates into its last element.
for idx in 1..nodes.len() {
let (prev_slice, curr_slice) = nodes.split_at_mut(idx);
match (
get_literal_bytes(prev_slice.last_mut().unwrap()),
get_literal_bytes(curr_slice.first_mut().unwrap()),
) {
(Some(prev_bytes), Some(curr_bytes))
if !prev_bytes.is_empty() && !curr_bytes.is_empty() =>
{
if walk.in_lookbehind {
// Lookbehind: result is curr ++ prev, and prev is left empty.
// NOTE(review): presumably this matches how the emitter matches
// byte sequences backwards — confirm, since the `Char` arm above
// encodes UTF-8 in forward order regardless of lookbehind.
curr_bytes.append(prev_bytes);
} else {
// Forward: build prev ++ curr in prev, then swap it into curr,
// leaving prev empty.
prev_bytes.append(curr_bytes);
core::mem::swap(prev_bytes, curr_bytes);
}
modified = true;
}
_ => (),
}
}
if modified {
PassAction::Modified
} else {
PassAction::Keep
}
}
_ => PassAction::Keep,
}
}
/// Convert a small, non-inverted bracket into an explicit `CharSet`.
///
/// Returns `None` if the bracket is inverted or covers more than
/// `MAX_CHAR_SET_LENGTH` code points.
fn try_reduce_bracket(bc: &BracketContents) -> Option<Node> {
    // An inverted bracket stands for the complement of its set; leave it alone.
    if bc.invert {
        return None;
    }
    let mut total = 0;
    for interval in bc.cps.intervals() {
        total += interval.count_codepoints();
    }
    if total > MAX_CHAR_SET_LENGTH {
        return None;
    }
    let mut members = Vec::new();
    for interval in bc.cps.intervals() {
        members.extend(interval.codepoints());
    }
    debug_assert!(members.len() <= MAX_CHAR_SET_LENGTH, "Unexpectedly many chars");
    Some(Node::CharSet(members))
}
/// Simplify a `Bracket` node.
///
/// If the bracket can be reduced to an explicit `CharSet` (see
/// `try_reduce_bracket`), it is replaced by that set. Otherwise, if its
/// complement has fewer intervals, the code-point set is stored inverted and the
/// `invert` flag is toggled to compensate.
fn simplify_brackets(n: &mut Node, _walk: &Walk) -> PassAction {
    let Node::Bracket(bc) = n else {
        return PassAction::Keep;
    };
    if let Some(reduced) = try_reduce_bracket(bc) {
        return PassAction::Replace(reduced);
    }
    let complement_is_smaller = bc.cps.intervals().len() > bc.cps.inverted_interval_count();
    if !complement_is_smaller {
        return PassAction::Keep;
    }
    bc.cps = bc.cps.inverted();
    bc.invert = !bc.invert;
    PassAction::Modified
}
/// Optimize `r`'s IR in place.
///
/// Bracket simplification runs once up front. The remaining passes then run in a
/// fixed order, round after round, until a full round has no pass reporting a
/// change. Case unfolding only runs for case-insensitive regexes, and byte
/// literal formation only runs without the `utf16` feature.
pub fn optimize(r: &mut Regex) {
    run_pass(r, &mut simplify_brackets);
    let mut progress = true;
    while progress {
        progress = run_pass(r, &mut decat);
        if r.flags.icase {
            progress |= run_pass(r, &mut unfold_icase_chars);
        }
        progress |= run_pass(r, &mut unroll_loops);
        progress |= run_pass(r, &mut promote_1char_loops);
        #[cfg(not(feature = "utf16"))]
        {
            progress |= run_pass(r, &mut form_literal_bytes);
        }
        progress |= run_pass(r, &mut remove_empties);
        progress |= run_pass(r, &mut propagate_early_fails);
    }
}