use crate::folds;
use crate::insn::{MAX_BYTE_SET_LENGTH, MAX_CHAR_SET_LENGTH};
use crate::ir::*;
use crate::types::BracketContents;
/// Upper bound on a loop's minimum iteration count for `unroll_loops` to expand it;
/// loops whose `quant.min` is larger than this are left untouched by that pass.
const LOOP_UNROLL_THRESHOLD: usize = 5;
/// What a pass callback asks the pass driver to do with the node it was given.
pub enum PassAction {
/// Leave the node as it is; nothing is recorded as changed.
Keep,
/// The callback already edited the node in place; only record that a change happened.
Modified,
/// Overwrite the node with `Node::Empty` and record a change.
Remove,
/// Overwrite the node with the carried node and record a change.
Replace(Node),
}
/// Drives a single pass callback over a node tree and tracks whether it changed anything.
#[derive(Debug)]
struct Pass<'a, F>
where
F: FnMut(&mut Node, &Walk) -> PassAction,
{
// The callback invoked for every node visited by the walk.
func: &'a mut F,
// Set to true whenever the callback returns anything other than `PassAction::Keep`.
changed: bool,
}
impl<'a, F> Pass<'a, F>
where
    F: FnMut(&mut Node, &Walk) -> PassAction,
{
    /// Wraps `func` in a pass that has not recorded any change yet.
    fn new(func: &'a mut F) -> Self {
        Pass {
            func,
            changed: false,
        }
    }

    /// Walks the tree rooted at `start` once, applying the callback to each node
    /// and carrying out the action it returns. Sets `self.changed` if any
    /// callback invocation returned something other than `PassAction::Keep`.
    fn run_postorder(&mut self, start: &mut Node) {
        walk_mut(
            true,
            start,
            &mut |n: &mut Node, walk: &mut Walk| match (self.func)(n, walk) {
                PassAction::Keep => {}
                PassAction::Modified => {
                    self.changed = true;
                }
                PassAction::Remove => {
                    *n = Node::Empty;
                    self.changed = true;
                }
                PassAction::Replace(newnode) => {
                    *n = newnode;
                    self.changed = true;
                }
            },
        )
    }

    /// Repeats `run_postorder` until a full sweep changes nothing.
    ///
    /// On return, `self.changed` is true iff at least one sweep changed the
    /// tree. The per-sweep flag is necessarily false when the loop exits, so the
    /// aggregate is tracked separately and stored back at the end; otherwise
    /// callers reading `changed` afterwards would always see `false`.
    fn run_to_fixpoint(&mut self, n: &mut Node) {
        debug_assert!(!self.changed, "Pass has already been run");
        let mut any_changed = false;
        loop {
            // Reset so that the flag reflects this sweep only.
            self.changed = false;
            self.run_postorder(n);
            if !self.changed {
                break;
            }
            any_changed = true;
        }
        self.changed = any_changed;
    }
}
/// Runs the callback `func` as a pass over the regex's node tree until it
/// reaches a fixpoint, then returns the pass's `changed` flag as it stands
/// after that run.
fn run_pass<F>(r: &mut Regex, func: &mut F) -> bool
where
    F: FnMut(&mut Node, &Walk) -> PassAction,
{
    let mut pass = Pass::new(func);
    let root = &mut r.node;
    pass.run_to_fixpoint(root);
    pass.changed
}
fn remove_empties(n: &mut Node, _w: &Walk) -> PassAction {
match n {
Node::Empty => PassAction::Keep,
Node::Goal => PassAction::Keep,
Node::Char { .. } => PassAction::Keep,
Node::ByteSequence(v) => {
if v.is_empty() {
PassAction::Remove
} else {
PassAction::Keep
}
}
Node::ByteSet(..) => PassAction::Keep,
Node::CharSet(..) => PassAction::Keep,
Node::Cat(nodes) => {
let blen = nodes.len();
nodes.retain(|nn| !nn.is_empty());
if nodes.len() == blen {
PassAction::Keep
} else {
match nodes.len() {
0 => PassAction::Remove,
1 => PassAction::Replace(nodes.pop().unwrap()),
_ => PassAction::Modified,
}
}
}
Node::Alt(left, right) => {
if left.is_empty() && right.is_empty() {
PassAction::Remove
} else {
PassAction::Keep
}
}
Node::MatchAny => PassAction::Keep,
Node::MatchAnyExceptLineTerminator => PassAction::Keep,
Node::Anchor { .. } => PassAction::Keep,
Node::Loop {
quant,
loopee,
enclosed_groups,
} => {
if loopee.is_empty() || (quant.max == 0 && enclosed_groups.start == enclosed_groups.end)
{
PassAction::Remove
} else {
PassAction::Keep
}
}
Node::Loop1CharBody { .. } => PassAction::Keep,
Node::CaptureGroup(..) => {
PassAction::Keep
}
Node::WordBoundary { .. } => PassAction::Keep,
Node::BackRef { .. } => PassAction::Keep,
Node::Bracket { .. } => PassAction::Keep,
Node::LookaroundAssertion {
negate, contents, ..
} => {
if !*negate && contents.is_empty() {
PassAction::Remove
} else {
PassAction::Keep
}
}
}
}
/// Pass callback that simplifies `Cat` nodes: an empty `Cat` is removed, a
/// single-child `Cat` is replaced by that child, and a `Cat` with `Cat`
/// children is replaced by one `Cat` with those children spliced in place
/// (one level deep, order preserved). Anything else is kept.
fn decat(n: &mut Node, _w: &Walk) -> PassAction {
    let children = match n {
        Node::Cat(nodes) => nodes,
        _ => return PassAction::Keep,
    };
    match children.len() {
        0 => PassAction::Remove,
        1 => PassAction::Replace(children.pop().unwrap()),
        _ if children.iter().any(|child| child.is_cat()) => {
            // Take ownership of the children so nested Cats can be moved out of.
            let mut flattened = Vec::new();
            for child in std::mem::take(children) {
                match child {
                    Node::Cat(grandchildren) => flattened.extend(grandchildren),
                    other => flattened.push(other),
                }
            }
            PassAction::Replace(Node::Cat(flattened))
        }
        _ => PassAction::Keep,
    }
}
/// Pass callback that rewrites a case-insensitive `Char` using the result of
/// `folds::unfold_char`: a char that unfolds only to itself becomes a
/// case-sensitive `Char`; one that unfolds to several chars becomes a `CharSet`.
/// All other nodes are kept.
///
/// # Panics
///
/// Panics if the unfolded list is empty or longer than `MAX_BYTE_SET_LENGTH`.
fn unfold_icase_chars(n: &mut Node, _w: &Walk) -> PassAction {
    let c = match *n {
        Node::Char { c, icase: true } => c,
        _ => return PassAction::Keep,
    };
    let unfolded = folds::unfold_char(c);
    debug_assert!(
        unfolded.contains(&c),
        "Char should always unfold to at least itself"
    );
    let variants = unfolded.len();
    if variants == 0 {
        panic!("Char should always unfold to at least itself")
    } else if variants == 1 {
        // Only itself: case-insensitivity has no effect on this char.
        PassAction::Replace(Node::Char { c, icase: false })
    } else if variants <= MAX_BYTE_SET_LENGTH {
        // NOTE(review): this bounds a CharSet by MAX_BYTE_SET_LENGTH, whereas
        // try_reduce_bracket bounds CharSets by MAX_CHAR_SET_LENGTH - confirm
        // which limit is intended here.
        PassAction::Replace(Node::CharSet(unfolded))
    } else {
        panic!("Unfolded to more characters than we believed possible")
    }
}
/// Pass callback that unrolls the mandatory iterations of a `Loop`.
///
/// A loop is left alone if it encloses capture groups, if its `min` is zero,
/// or if its `min` exceeds `LOOP_UNROLL_THRESHOLD`. Otherwise the node becomes
/// a `Cat` of `min` duplicates of the body, followed by the original loop with
/// `min` set to zero and `max` reduced by the old `min` (the trailing loop is
/// omitted when that reduced `max` is zero).
fn unroll_loops(n: &mut Node, _w: &Walk) -> PassAction {
    match n {
        Node::Loop {
            loopee,
            quant,
            enclosed_groups,
        } => {
            let encloses_groups = enclosed_groups.start < enclosed_groups.end;
            let unrollable = quant.min != 0 && quant.min <= LOOP_UNROLL_THRESHOLD;
            if encloses_groups || !unrollable {
                return PassAction::Keep;
            }

            // One copy of the body per mandatory iteration.
            let mut pieces: Vec<Node> = (0..quant.min)
                .map(|_| loopee.as_mut().duplicate())
                .collect();

            // The remaining loop only covers the optional iterations.
            // NOTE(review): assumes quant.max >= quant.min - confirm this is
            // guaranteed by whoever constructs the quantifier.
            quant.max -= quant.min;
            quant.min = 0;
            let keep_remainder = quant.max > 0;
            if keep_remainder {
                // Move the (now adjusted) loop node itself to the tail.
                pieces.push(std::mem::replace(n, Node::Empty));
            }
            *n = Node::Cat(pieces);
            PassAction::Modified
        }
        _ => PassAction::Keep,
    }
}
/// Pass callback that converts a `Loop` whose body reports
/// `matches_exactly_one_char()` into a `Loop1CharBody` with the same body and
/// quantifier. Other nodes are kept.
///
/// # Panics
///
/// Panics if such a loop nevertheless has a non-empty `enclosed_groups` range.
fn promote_1char_loops(n: &mut Node, _w: &Walk) -> PassAction {
    if let Node::Loop {
        loopee,
        quant,
        enclosed_groups,
    } = n
    {
        if loopee.matches_exactly_one_char() {
            assert!(
                enclosed_groups.start >= enclosed_groups.end,
                "Should have no enclosed groups"
            );
            // Steal the body, leaving a placeholder in the node being overwritten.
            let body = std::mem::replace(loopee, Box::new(Node::Empty));
            let quant = *quant;
            *n = Node::Loop1CharBody {
                loopee: body,
                quant,
            };
            return PassAction::Modified;
        }
    }
    PassAction::Keep
}
/// Pass callback that lowers literal characters to bytes and merges them.
///
/// - A case-sensitive `Char` becomes a `ByteSequence` holding its UTF-8 encoding.
/// - A `CharSet` made only of ASCII chars becomes a `ByteSet`.
/// - Within a `Cat`, each pair of adjacent non-empty `ByteSequence`s is merged
///   into the later node, leaving the earlier one empty. Outside a lookbehind
///   the merged bytes are earlier-then-later; inside a lookbehind
///   (`walk.in_lookbehind`) they are later-then-earlier.
fn form_literal_bytes(n: &mut Node, walk: &Walk) -> PassAction {
    /// Returns the byte vector if `node` is a `ByteSequence`.
    fn as_byte_sequence(node: &mut Node) -> Option<&mut Vec<u8>> {
        if let Node::ByteSequence(bytes) = node {
            Some(bytes)
        } else {
            None
        }
    }

    match n {
        Node::Char { c, icase } if !*icase => {
            let mut utf8 = [0; 4];
            let encoded = c.encode_utf8(&mut utf8).as_bytes().to_vec();
            PassAction::Replace(Node::ByteSequence(encoded))
        }
        Node::CharSet(chars) if chars.iter().all(char::is_ascii) => {
            let bytes = chars.iter().map(|ch| *ch as u32 as u8).collect();
            PassAction::Replace(Node::ByteSet(bytes))
        }
        Node::Cat(nodes) => {
            let mut merged_any = false;
            for split in 1..nodes.len() {
                // Borrow the two neighbours on either side of `split` at once.
                let (head, tail) = nodes.split_at_mut(split);
                let earlier = as_byte_sequence(head.last_mut().unwrap());
                let later = as_byte_sequence(tail.first_mut().unwrap());
                if let (Some(earlier), Some(later)) = (earlier, later) {
                    if earlier.is_empty() || later.is_empty() {
                        continue;
                    }
                    if walk.in_lookbehind {
                        later.append(earlier);
                    } else {
                        // Concatenate into `earlier`, then hand the result to
                        // `later` so the next pair can keep accumulating.
                        earlier.append(later);
                        std::mem::swap(earlier, later);
                    }
                    merged_any = true;
                }
            }
            if merged_any {
                PassAction::Modified
            } else {
                PassAction::Keep
            }
        }
        _ => PassAction::Keep,
    }
}
/// Tries to express a bracket as a small explicit `Node::CharSet`.
///
/// Returns `None` when the bracket is inverted, when it covers more than
/// `MAX_CHAR_SET_LENGTH` code points, or when any covered code point is not a
/// valid `char` (i.e. `char::from_u32` rejects it).
fn try_reduce_bracket(bc: &BracketContents) -> Option<Node> {
    if bc.invert {
        return None;
    }

    // Count code points, giving up as soon as the limit is exceeded so that a
    // bracket with many intervals is not summed in full for nothing.
    let mut cps_count = 0;
    for iv in bc.cps.intervals() {
        cps_count += iv.count_codepoints();
        if cps_count > MAX_CHAR_SET_LENGTH {
            return None;
        }
    }

    // Small enough: enumerate the code points as chars.
    let mut res = Vec::with_capacity(cps_count);
    for iv in bc.cps.intervals() {
        for cp in iv.codepoints() {
            res.push(std::char::from_u32(cp)?);
        }
    }
    debug_assert!(res.len() <= MAX_CHAR_SET_LENGTH, "Unexpectedly many chars");
    Some(Node::CharSet(res))
}
/// Pass callback that simplifies `Bracket` nodes.
///
/// If `try_reduce_bracket` succeeds, the bracket is replaced by its result.
/// Otherwise, if the code point set has more intervals than its inverse would,
/// the set is swapped for its inverse and the `invert` flag is toggled.
/// Non-bracket nodes are kept.
fn simplify_brackets(n: &mut Node, _walk: &Walk) -> PassAction {
    if let Node::Bracket(bc) = n {
        if let Some(reduced) = try_reduce_bracket(bc) {
            return PassAction::Replace(reduced);
        }
        let current_intervals = bc.cps.intervals().len();
        if current_intervals > bc.cps.inverted_interval_count() {
            bc.cps = bc.cps.inverted();
            bc.invert = !bc.invert;
            return PassAction::Modified;
        }
    }
    PassAction::Keep
}
/// Optimizes the regex's node tree in place.
///
/// Runs `simplify_brackets` once, then repeats a round of the remaining passes
/// (`decat`, `unfold_icase_chars` when the `icase` flag is set, `unroll_loops`,
/// `promote_1char_loops`, `form_literal_bytes`, `remove_empties`) for as long as
/// `run_pass` returns true for at least one pass in the round.
pub fn optimize(r: &mut Regex) {
    /// Runs every repeating pass once, in order; true if any reported a change.
    fn run_round(r: &mut Regex) -> bool {
        let mut changed = run_pass(r, &mut decat);
        if r.flags.icase {
            changed |= run_pass(r, &mut unfold_icase_chars);
        }
        changed |= run_pass(r, &mut unroll_loops);
        changed |= run_pass(r, &mut promote_1char_loops);
        changed |= run_pass(r, &mut form_literal_bytes);
        changed |= run_pass(r, &mut remove_empties);
        changed
    }

    run_pass(r, &mut simplify_brackets);
    while run_round(r) {}
}