use regex_syntax::hir::{Hir, HirKind};
use crate::query::GramQuery;
use crate::tokenizer::build_covering_inner;
/// Upper bound on how many literal strings `exact_literal_variants` may enumerate for a
/// single node; when it is exceeded, `exact_literal_query` returns `None` and `walk` falls
/// back to structural decomposition.
const MAX_LITERAL_VARIANTS: usize = 16;
/// Parses `pattern` and derives the gram query that text matching it must satisfy.
///
/// `case_insensitive` turns on the parser's case-insensitive flag for the whole
/// pattern. UTF-8 enforcement is disabled so that byte-oriented patterns such as
/// `(?-u)\xFF` are accepted.
///
/// # Errors
///
/// Returns the parser's error message, rendered as a `String`, when `pattern` is not
/// a valid regex.
pub fn decompose(pattern: &str, case_insensitive: bool) -> Result<GramQuery, String> {
    let mut builder = regex_syntax::ParserBuilder::new();
    builder.case_insensitive(case_insensitive).utf8(false);
    match builder.build().parse(pattern) {
        Ok(parsed) => Ok(walk(&parsed)),
        Err(err) => Err(err.to_string()),
    }
}
/// Translates a regex HIR node into a gram query that any matching text must satisfy.
///
/// A node whose full set of matching strings is small and finite is handled by
/// `exact_literal_query`. Every other node is decomposed structurally, yielding
/// `GramQuery::All` (no constraint) wherever a node contributes no usable grams.
fn walk(hir: &Hir) -> GramQuery {
    exact_literal_query(hir).unwrap_or_else(|| match hir.kind() {
        HirKind::Literal(lit) => build_covering_inner(&lit.0)
            .filter(|grams| !grams.is_empty())
            .map_or(GramQuery::All, GramQuery::Grams),
        // Every part of a concatenation must be present.
        HirKind::Concat(parts) => GramQuery::And(parts.iter().map(walk).collect()),
        // At least one branch of an alternation must be present.
        HirKind::Alternation(branches) => GramQuery::Or(branches.iter().map(walk).collect()),
        // A repetition that may occur zero times imposes no requirement at all.
        HirKind::Repetition(rep) if rep.min == 0 => GramQuery::All,
        // Otherwise at least one copy of the repeated sub-expression must appear.
        HirKind::Repetition(rep) => walk(&rep.sub),
        HirKind::Capture(group) => walk(&group.sub),
        HirKind::Class(_) | HirKind::Look(_) | HirKind::Empty => GramQuery::All,
    })
}
/// Builds a query from the finite set of literals `hir` can match, if one is available.
///
/// Each literal becomes one `GramQuery::Grams` branch, and several branches are joined
/// with `GramQuery::Or`. Returns `None` so that the caller falls back to structural
/// decomposition in these cases:
/// - the literal set is unavailable (see `exact_literal_variants`) or empty;
/// - any literal yields no covering grams, or `build_covering_inner` returns `None` for it.
fn exact_literal_query(hir: &Hir) -> Option<GramQuery> {
    let branches = exact_literal_variants(hir, MAX_LITERAL_VARIANTS)?
        .into_iter()
        .map(|literal| {
            build_covering_inner(&literal)
                .filter(|grams| !grams.is_empty())
                .map(GramQuery::Grams)
        })
        // Short-circuits on the first literal that cannot be covered.
        .collect::<Option<Vec<_>>>()?;
    if branches.len() > 1 {
        Some(GramQuery::Or(branches))
    } else {
        // Zero branches -> None, one branch -> that branch unwrapped.
        branches.into_iter().next()
    }
}
/// Enumerates the byte strings `hir` can match, when that set is small and finite.
///
/// Returns `None` in these cases:
/// - the set is unbounded or not enumerated here: character classes, repetitions whose
///   count is not fixed, and fixed counts above 4;
/// - listing the set would exceed `limit` entries.
///
/// The returned strings are not deduplicated.
///
/// Look-around assertions (`^`, `$`, `\b`, ...) are treated exactly like `Empty`. They
/// always match zero bytes, so ignoring them can only add candidate strings and never
/// drops a string the regex actually matches. The result is therefore a superset of the
/// true match set, which is sufficient for a necessary-condition gram filter.
fn exact_literal_variants(hir: &Hir, limit: usize) -> Option<Vec<Vec<u8>>> {
    match hir.kind() {
        // Zero-width nodes contribute exactly one variant: the empty string.
        HirKind::Empty | HirKind::Look(_) => Some(vec![Vec::new()]),
        HirKind::Literal(lit) => Some(vec![lit.0.to_vec()]),
        HirKind::Capture(cap) => exact_literal_variants(&cap.sub, limit),
        HirKind::Concat(subs) => {
            // Cartesian product of every part's variants, bounded by `limit`.
            let mut acc = vec![Vec::new()];
            for sub in subs {
                let variants = exact_literal_variants(sub, limit)?;
                acc = concat_variants(acc, variants, limit)?;
            }
            Some(acc)
        }
        HirKind::Alternation(subs) => {
            // Union of every branch's variants, bounded by `limit`.
            let mut acc = Vec::new();
            for sub in subs {
                let variants = exact_literal_variants(sub, limit)?;
                if acc.len() + variants.len() > limit {
                    return None;
                }
                acc.extend(variants);
            }
            Some(acc)
        }
        // Only fixed, small counts (`x{n}` with n <= 4) are expanded, as the n-fold
        // product of `x`'s variants with themselves.
        HirKind::Repetition(rep) => match rep.max {
            Some(max) if max == rep.min && max <= 4 => {
                let variants = exact_literal_variants(&rep.sub, limit)?;
                let mut acc = vec![Vec::new()];
                for _ in 0..rep.min {
                    acc = concat_variants(acc, variants.clone(), limit)?;
                }
                Some(acc)
            }
            _ => None,
        },
        // Classes are never expanded, however small.
        HirKind::Class(_) => None,
    }
}
/// Forms the cartesian concatenation of two literal sets.
///
/// The result contains every `l ++ r` with `l` from `left` and `r` from `right`. It is
/// ordered by `left` first, then by `right`.
///
/// If either side is empty, returns `Some(vec![])` without checking `limit`. Otherwise
/// returns `None` when the product would hold more than `limit` entries, or when its
/// size overflows `usize`.
fn concat_variants(left: Vec<Vec<u8>>, right: Vec<Vec<u8>>, limit: usize) -> Option<Vec<Vec<u8>>> {
    if right.is_empty() || left.is_empty() {
        return Some(Vec::new());
    }
    let size = left
        .len()
        .checked_mul(right.len())
        .filter(|&count| count <= limit)?;
    let mut product = Vec::with_capacity(size);
    for head in &left {
        product.extend(
            right
                .iter()
                .map(|tail| [head.as_slice(), tail.as_slice()].concat()),
        );
    }
    Some(product)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A byte-oriented pattern that can match non-UTF-8 input must still parse and
    /// produce a query other than `GramQuery::None`.
    #[test]
    fn decompose_accepts_patterns_that_can_match_invalid_utf8() {
        let parsed = decompose(r"(?-u)\xFFneedle\x80", false);
        let query = parsed.expect("byte regex should parse");
        assert!(!matches!(query, GramQuery::None));
    }

    /// An optional group contributes no required grams, so `(foo)?needle` must
    /// simplify to the same query as plain `needle`.
    #[test]
    fn optional_prefix_does_not_force_required_grams() {
        let simplified = |pattern: &str| {
            decompose(pattern, false)
                .expect("regex should parse")
                .simplify()
        };
        let with_prefix = simplified(r"(foo)?needle");
        let without_prefix = simplified(r"needle");
        assert_eq!(format!("{with_prefix:?}"), format!("{without_prefix:?}"));
    }
}