use super::{Hir, HirCapture, HirExpr};
use std::collections::HashMap;
/// Minimum number of literal alternatives an alternation must contain before
/// `optimize_alternation` merges them into a prefix trie; with fewer, the
/// alternation is rebuilt without prefix factoring.
const MIN_LITERALS_FOR_OPTIMIZATION: usize = 4;
/// A byte-level prefix trie used to factor common prefixes out of literal alternations.
#[derive(Debug, Default)]
struct TrieNode {
    /// Outgoing edges, keyed by the next byte of the literal.
    children: HashMap<u8, TrieNode>,
    /// True when some inserted literal ends exactly at this node.
    is_terminal: bool,
    /// Capture index recorded by `insert` for the literal ending here. It is not
    /// emitted by `to_hir`: a group around a whole literal cannot survive prefix
    /// factoring, so captured literals must be kept out of the trie by the caller.
    capture_index: Option<u32>,
    /// Capture name recorded alongside `capture_index`; likewise not emitted.
    capture_name: Option<String>,
}

impl TrieNode {
    /// Creates an empty trie: a non-terminal root with no edges.
    fn new() -> Self {
        Self::default()
    }

    /// Inserts `bytes` as a complete literal, creating nodes along the path as needed
    /// and marking the last node terminal. Inserting the empty slice marks the root.
    ///
    /// The capture arguments overwrite whatever the terminal node held before; they
    /// are stored only and are not reflected in the expression built by `to_hir`.
    fn insert(&mut self, bytes: &[u8], capture_index: Option<u32>, capture_name: Option<String>) {
        let mut node = self;
        for &byte in bytes {
            node = node.children.entry(byte).or_default();
        }
        node.is_terminal = true;
        node.capture_index = capture_index;
        node.capture_name = capture_name;
    }

    /// Builds an expression matching exactly the literal suffixes stored at and below
    /// this node. On the root, this is every inserted literal, including the empty
    /// one when it was inserted.
    ///
    /// Shape of the result:
    /// - no edges: `Empty` (a leaf, or an empty root);
    /// - terminal with edges: `Alt[Empty, branch, ...]`, so the literal ending here is
    ///   tried before any longer literal that shares it as a prefix;
    /// - non-terminal with one edge: that branch alone; with several: `Alt[branch, ...]`.
    ///
    /// Branches are ordered by edge byte, so the output is deterministic despite the
    /// `HashMap`. Each node is converted exactly once, which keeps the cost linear in
    /// the size of the trie (the previous version re-converted terminal subtrees and
    /// was exponential for nested prefixes like `a|aa|aaa|...`).
    fn to_hir(&self) -> HirExpr {
        let mut edges: Vec<(u8, &TrieNode)> =
            self.children.iter().map(|(&byte, node)| (byte, node)).collect();
        edges.sort_unstable_by_key(|&(byte, _)| byte);
        let mut branches: Vec<HirExpr> = edges
            .into_iter()
            .map(|(byte, child)| Self::branch_to_hir(byte, child))
            .collect();

        if self.is_terminal && !branches.is_empty() {
            // The literal ending here is an alternative in its own right.
            branches.insert(0, HirExpr::Empty);
            return HirExpr::Alt(branches);
        }
        match branches.len() {
            0 => HirExpr::Empty,
            1 => branches.pop().expect("length checked above"),
            _ => HirExpr::Alt(branches),
        }
    }

    /// Builds the expression for one edge: the byte `byte`, followed by everything
    /// reachable through `child`.
    ///
    /// A run of non-terminal, single-child nodes is walked iteratively and emitted as
    /// a flat `Concat` of one-byte literals. This keeps recursion depth proportional to
    /// the number of branch points rather than to the literal length.
    fn branch_to_hir(byte: u8, child: &TrieNode) -> HirExpr {
        let mut parts = vec![HirExpr::Literal(vec![byte])];
        let mut node = child;
        while !node.is_terminal && node.children.len() == 1 {
            let (&next_byte, next_node) = node
                .children
                .iter()
                .next()
                .expect("loop condition guarantees exactly one child");
            parts.push(HirExpr::Literal(vec![next_byte]));
            node = next_node;
        }
        match node.to_hir() {
            HirExpr::Empty => {}
            HirExpr::Concat(rest) => parts.extend(rest),
            other => parts.push(other),
        }
        if parts.len() == 1 {
            parts.pop().expect("length checked above")
        } else {
            HirExpr::Concat(parts)
        }
    }
}
/// Runs the prefix-factoring pass over a whole HIR.
///
/// The expression is rewritten by `optimize_expr`; `props` is carried over
/// unchanged. NOTE(review): if `props` is derived from the expression's structure,
/// confirm it remains valid after factoring.
pub fn optimize_prefixes(hir: Hir) -> Hir {
    let Hir { expr, props } = hir;
    let expr = optimize_expr(expr);
    Hir { expr, props }
}
/// Recursively applies prefix factoring to every alternation nested in `expr`.
///
/// Alternations are handed to `optimize_alternation`. Concatenations, repetitions,
/// capture groups and lookarounds are descended into with all their other fields
/// preserved. Every other node kind is returned unchanged.
fn optimize_expr(expr: HirExpr) -> HirExpr {
    // Optimizes the sub-expression stored in `slot` in place, so the boxed wrapper
    // node (and its other fields) is reused as-is.
    fn rewrite(slot: &mut HirExpr) {
        let inner = std::mem::replace(slot, HirExpr::Empty);
        *slot = optimize_expr(inner);
    }

    match expr {
        HirExpr::Alt(branches) => optimize_alternation(branches),
        HirExpr::Concat(items) => HirExpr::Concat(items.into_iter().map(optimize_expr).collect()),
        HirExpr::Repeat(mut repeat) => {
            rewrite(&mut repeat.expr);
            HirExpr::Repeat(repeat)
        }
        HirExpr::Capture(mut capture) => {
            rewrite(&mut capture.expr);
            HirExpr::Capture(capture)
        }
        HirExpr::Lookaround(mut look) => {
            rewrite(&mut look.expr);
            HirExpr::Lookaround(look)
        }
        leaf => leaf,
    }
}
/// Rewrites an alternation, merging its literal alternatives into a prefix trie.
///
/// Each alternative is classified with `extract_literal`:
/// - plain literals (a `Literal`, or a `Concat` made only of literals) are merge
///   candidates;
/// - a capture group around a literal is kept as its own alternative. Once its bytes
///   are shared with other literals in the trie, the group can no longer be
///   expressed, so merging it would silently drop the capture;
/// - anything else is optimized recursively with `optimize_expr`.
///
/// With fewer than `MIN_LITERALS_FOR_OPTIMIZATION` plain literals, the alternatives
/// are returned in their original order: literal concatenations are flattened to
/// single literals and the other alternatives are optimized. Otherwise, all plain
/// literals are replaced by one trie-derived expression placed where the first of
/// them stood, and the remaining alternatives keep their original relative order.
/// A single resulting alternative is returned unwrapped.
///
/// NOTE(review): prefix factoring can still change which alternative wins under
/// leftmost-first (backtracking) semantics. For example, `testing|test` becomes a form
/// that tries `test` first, and literals listed after a non-literal alternative are
/// moved ahead of it. Confirm this is acceptable for the engine consuming this HIR.
fn optimize_alternation(variants: Vec<HirExpr>) -> HirExpr {
    // One alternative after classification.
    enum Branch {
        // Literal bytes eligible for trie merging.
        Plain(Vec<u8>),
        // An alternative emitted as-is (already optimized / re-wrapped).
        Kept(HirExpr),
    }

    let branches: Vec<Branch> = variants
        .into_iter()
        .map(|variant| match extract_literal(&variant) {
            Some((bytes, None, _)) => Branch::Plain(bytes),
            Some((bytes, Some(index), name)) => {
                Branch::Kept(wrap_in_capture(HirExpr::Literal(bytes), Some(index), name))
            }
            None => Branch::Kept(optimize_expr(variant)),
        })
        .collect();

    let literal_count = branches
        .iter()
        .filter(|branch| matches!(branch, Branch::Plain(_)))
        .count();
    let merge = literal_count >= MIN_LITERALS_FOR_OPTIMIZATION;

    let mut trie = TrieNode::new();
    let mut trie_slot: Option<usize> = None;
    let mut result: Vec<HirExpr> = Vec::with_capacity(branches.len());
    for branch in branches {
        match branch {
            Branch::Plain(bytes) => {
                if merge {
                    trie.insert(&bytes, None, None);
                    if trie_slot.is_none() {
                        // Placeholder at the first literal's position; replaced by the
                        // trie expression once every literal has been inserted.
                        trie_slot = Some(result.len());
                        result.push(HirExpr::Empty);
                    }
                } else {
                    result.push(HirExpr::Literal(bytes));
                }
            }
            Branch::Kept(expr) => result.push(expr),
        }
    }
    if let Some(slot) = trie_slot {
        result[slot] = trie.to_hir();
    }

    if result.len() == 1 {
        result.pop().expect("length checked above")
    } else {
        HirExpr::Alt(result)
    }
}
/// Returns the literal bytes matched by `expr` when it is literal-only, together with
/// the capture index and name when the literal is directly wrapped in a capture group.
///
/// Recognised shapes:
/// - `Literal(bytes)`: the bytes, with no capture;
/// - `Capture` whose inner expression is a `Literal`: the bytes plus its index and name;
/// - `Concat` whose parts are all `Literal`s: the bytes joined in order, with no
///   capture (an empty `Concat` yields empty bytes).
///
/// Returns `None` for anything else.
fn extract_literal(expr: &HirExpr) -> Option<(Vec<u8>, Option<u32>, Option<String>)> {
    match expr {
        HirExpr::Literal(bytes) => Some((bytes.to_vec(), None, None)),
        HirExpr::Capture(cap) => match &cap.expr {
            HirExpr::Literal(bytes) => Some((bytes.to_vec(), Some(cap.index), cap.name.clone())),
            _ => None,
        },
        HirExpr::Concat(parts) => {
            let chunks: Option<Vec<&[u8]>> = parts
                .iter()
                .map(|part| match part {
                    HirExpr::Literal(bytes) => Some(bytes.as_slice()),
                    _ => None,
                })
                .collect();
            chunks.map(|chunks| (chunks.concat(), None, None))
        }
        _ => None,
    }
}
/// Wraps `expr` in a capture group with index `cap_idx` and optional name `cap_name`.
///
/// When `cap_idx` is `None`, `expr` is returned untouched and `cap_name` is ignored.
fn wrap_in_capture(expr: HirExpr, cap_idx: Option<u32>, cap_name: Option<String>) -> HirExpr {
    if let Some(index) = cap_idx {
        HirExpr::Capture(Box::new(HirCapture {
            index,
            name: cap_name,
            expr,
        }))
    } else {
        expr
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hir::HirProps;

    // Literal alternative built from the UTF-8 bytes of `text`.
    fn lit(text: &str) -> HirExpr {
        HirExpr::Literal(Vec::from(text.as_bytes()))
    }

    // Literal alternatives for each word, in order.
    fn lits(words: &[&str]) -> Vec<HirExpr> {
        words.iter().map(|&word| lit(word)).collect()
    }

    // Alternation over `branches`, in the given order.
    fn alt(branches: Vec<HirExpr>) -> HirExpr {
        HirExpr::Alt(branches)
    }

    // Runs `optimize_prefixes` on `expr` wrapped with default properties and returns
    // only the rewritten expression.
    fn optimized(expr: HirExpr) -> HirExpr {
        let hir = Hir {
            expr,
            props: HirProps::default(),
        };
        optimize_prefixes(hir).expr
    }

    // True when `expr` is the one-byte literal `t`.
    fn is_literal_t(expr: &HirExpr) -> bool {
        matches!(expr, HirExpr::Literal(b) if b == b"t")
    }

    #[test]
    fn test_trie_simple() {
        let words: [&[u8]; 4] = [b"the", b"that", b"them", b"they"];
        let mut trie = TrieNode::new();
        for word in &words {
            trie.insert(word, None, None);
        }
        let hir = trie.to_hir();
        if let HirExpr::Concat(parts) = &hir {
            assert!(is_literal_t(&parts[0]));
        } else {
            panic!("Expected Concat, got {:?}", hir);
        }
    }

    #[test]
    fn test_optimize_small_alternation() {
        let result = optimized(alt(lits(&["a", "b", "c"])));
        assert!(matches!(result, HirExpr::Alt(_)));
    }

    #[test]
    fn test_optimize_large_alternation() {
        let result = optimized(alt(lits(&["the", "that", "them", "they"])));
        if let HirExpr::Concat(parts) = &result {
            assert!(is_literal_t(&parts[0]));
        } else {
            panic!("Expected optimized to Concat, got {:?}", result);
        }
    }

    #[test]
    fn test_optimize_mixed() {
        let mut branches = lits(&["the", "that", "them", "they"]);
        branches.push(HirExpr::Repeat(Box::new(super::super::HirRepeat {
            expr: lit("x"),
            min: 1,
            max: None,
            greedy: true,
        })));
        assert!(matches!(optimized(alt(branches)), HirExpr::Alt(_)));
    }

    #[test]
    fn test_no_common_prefix() {
        let result = optimized(alt(lits(&["apple", "banana", "cherry", "date"])));
        assert!(matches!(result, HirExpr::Alt(_)));
    }

    #[test]
    fn test_partial_overlap() {
        let words = ["test", "testing", "tested", "tester", "apple", "application"];
        match optimized(alt(lits(&words))) {
            HirExpr::Alt(branches) => assert_eq!(branches.len(), 2),
            other => panic!("Expected Alt with 2 branches, got {:?}", other),
        }
    }
}