use aho_corasick::{AhoCorasick, MatchKind};
use regex_syntax::hir::literal::{ExtractKind, Extractor};
use regex_syntax::hir::{Class, Hir, HirKind, Literal, Look};
use super::plan::{Anchor, Plan, ShapeOp, SmallByteSet};
use super::{BuildError, MatcherTier};
/// Largest number of bytes a character class may expand to. `collect_class_bytes`
/// returns `None` once a class exceeds this, and the literal extractor is given
/// the same value as its per-class limit.
const CLASS_BYTE_CAP: usize = 16;
/// Largest number of literals accepted for a literal-set plan. Also handed to
/// the literal extractor as its total-literal limit.
const LITERAL_SET_CAP: usize = 2048;
/// Outcome of [`classify`]: the plan chosen for a pattern together with a
/// descriptor explaining the choice.
pub struct Classified {
/// Execution plan selected for the pattern.
pub plan: Plan,
/// Tier, reason code and hint describing why `plan` was selected.
pub descriptor: Descriptor,
}
/// Explains which matcher tier a pattern landed in and why.
#[derive(Debug, Clone)]
pub struct Descriptor {
/// Matcher tier the pattern was assigned to.
pub tier: MatcherTier,
/// Short reason code, e.g. `"shape:contains-byte"` or `"complex-shape"`.
pub reason: &'static str,
/// Advice for the pattern author; the empty string when a fast path was
/// taken, populated when the pattern fell back to the regex tier.
pub hint: &'static str,
}
pub fn classify(pattern: &str, case_insensitive: bool) -> Result<Classified, BuildError> {
if pattern.is_empty() {
return Err(BuildError::Empty);
}
let hir = regex_syntax::parse(pattern).map_err(|e| BuildError::Syntax {
pattern: pattern.to_string(),
source: Box::new(e),
hint: SYNTAX_HINT,
})?;
if let Some((reason, hint)) = hard_reject(&hir) {
return Ok(meta_plan(pattern, reason, hint));
}
if !case_insensitive && let Some(shape) = try_simple_shape(&hir) {
let descriptor = Descriptor {
tier: tier_for_shape(&shape),
reason: shape_reason(&shape),
hint: "",
};
return Ok(Classified {
plan: Plan::Shape(shape),
descriptor,
});
}
if let Some((ac, anchor, literals)) = try_literal_alternation(&hir, case_insensitive) {
let descriptor = Descriptor {
tier: MatcherTier::LiteralSet,
reason: "literal-only:alternation",
hint: "",
};
return Ok(Classified {
plan: Plan::LiteralOnly {
ac: Box::new(ac),
anchor,
literals: literals_to_boxed(literals),
case_insensitive,
},
descriptor,
});
}
if case_insensitive
&& let Some((ac, anchor, literals)) = try_single_literal_for_case_insensitive(&hir)
{
let descriptor = Descriptor {
tier: MatcherTier::LiteralSet,
reason: "literal-only:case-insensitive",
hint: "",
};
return Ok(Classified {
plan: Plan::LiteralOnly {
ac: Box::new(ac),
anchor,
literals: literals_to_boxed(literals),
case_insensitive,
},
descriptor,
});
}
if let Some((ac, anchor, literals)) = try_extractor_fallback(&hir, case_insensitive) {
let descriptor = Descriptor {
tier: MatcherTier::LiteralSet,
reason: "literal-only:extracted",
hint: "",
};
return Ok(Classified {
plan: Plan::LiteralOnly {
ac: Box::new(ac),
anchor,
literals: literals_to_boxed(literals),
case_insensitive,
},
descriptor,
});
}
Ok(meta_plan(
pattern,
"complex-shape",
"the pattern combines anchors, classes, and quantifiers in a way that \
prevents safe literal reduction; consider splitting into smaller \
patterns or accepting Meta-tier cost",
))
}
/// Hint attached to `BuildError::Syntax` when `classify` cannot parse a pattern.
const SYNTAX_HINT: &str = "check the regex syntax and ensure it parses under \
the regex-syntax crate's default flags";
fn hard_reject(hir: &Hir) -> Option<(&'static str, &'static str)> {
use HirKind as K;
match hir.kind() {
K::Look(l) => match l {
Look::Start | Look::End => None,
Look::WordAscii
| Look::WordAsciiNegate
| Look::WordUnicode
| Look::WordUnicodeNegate
| Look::WordStartAscii
| Look::WordEndAscii
| Look::WordStartUnicode
| Look::WordEndUnicode
| Look::WordStartHalfAscii
| Look::WordEndHalfAscii
| Look::WordStartHalfUnicode
| Look::WordEndHalfUnicode => Some((
"word-boundary",
"remove the word-boundary assertion if the surrounding context \
is known, or accept Meta-tier cost",
)),
Look::StartLF | Look::EndLF | Look::StartCRLF | Look::EndCRLF => Some((
"multiline-anchor",
"split the haystack by line and use a non-multiline pattern, \
or accept Meta-tier cost",
)),
},
K::Repetition(r) => {
if r.max.is_none_or(|max| max != r.min) {
Some((
"unbounded-quantifier",
"bound the quantifier (e.g. {1,64}) so literal extraction \
can yield finite exact prefixes, or accept Meta-tier cost",
))
} else {
hard_reject(&r.sub)
}
}
K::Capture(c) => {
hard_reject(&c.sub)
}
K::Class(cls) => {
if class_byte_count(cls).is_none() {
Some((
"unicode-class",
"if ASCII-only is acceptable, use (?-u) to disable unicode \
classes, or accept Meta-tier cost",
))
} else {
None
}
}
K::Concat(parts) | K::Alternation(parts) => parts.iter().find_map(hard_reject),
K::Empty | K::Literal(_) => None,
}
}
/// Tries to reduce `hir` to a single `ShapeOp`.
///
/// Capture groups wrapping the whole expression are looked through. A literal
/// goes to `simple_literal`, a class to `simple_class`, and a concatenation to
/// `anchored_literal`; anything else yields `None`.
fn try_simple_shape(hir: &Hir) -> Option<ShapeOp> {
    // Peel any number of enclosing capture groups.
    let mut node = hir;
    while let HirKind::Capture(cap) = node.kind() {
        node = &*cap.sub;
    }
    match node.kind() {
        HirKind::Literal(lit) => simple_literal(&lit.0),
        HirKind::Class(cls) => simple_class(cls),
        HirKind::Concat(parts) => anchored_literal(parts),
        _ => None,
    }
}
/// Builds the unanchored search shape for a literal.
///
/// Returns `None` for an empty literal, `ShapeOp::ContainsByte` for a single
/// byte, and `ShapeOp::Contains` wrapping an owned `memmem` finder otherwise.
fn simple_literal(bytes: &[u8]) -> Option<ShapeOp> {
    match bytes {
        [] => None,
        [single] => Some(ShapeOp::ContainsByte(*single)),
        _ => {
            let finder = memchr::memmem::Finder::new(bytes).into_owned();
            Some(ShapeOp::Contains(Box::new(finder)))
        }
    }
}
/// Builds a byte-level shape for a small character class.
///
/// The class is expanded with `collect_class_bytes`. One byte becomes
/// `ShapeOp::ContainsByte`, two or three bytes become `ShapeOp::ByteSet`, and
/// any other size (including an unexpandable class) yields `None`.
fn simple_class(cls: &Class) -> Option<ShapeOp> {
    let members = collect_class_bytes(cls)?;
    let shape = match members.len() {
        1 => ShapeOp::ContainsByte(members[0]),
        2 | 3 => ShapeOp::ByteSet(SmallByteSet::new(&members)),
        _ => return None,
    };
    Some(shape)
}
/// Interprets a concatenation as an optionally anchored literal.
///
/// After `strip_anchors` removes a leading `Look::Start` and/or trailing
/// `Look::End`, the remainder must consist solely of literal bytes. The anchor
/// combination selects `ExactMatch`, `StartsWith`, `EndsWith`, or — with no
/// anchors — whatever `simple_literal` produces. Returns `None` when the body
/// is not purely literal or is empty.
fn anchored_literal(parts: &[Hir]) -> Option<ShapeOp> {
    let (at_start, inner, at_end) = strip_anchors(parts)?;
    let literal = collect_literal_bytes(inner)?;
    if literal.is_empty() {
        return None;
    }
    if !at_start && !at_end {
        return simple_literal(&literal);
    }
    let boxed = literal.into_boxed_slice();
    let shape = if at_start && at_end {
        ShapeOp::ExactMatch(boxed)
    } else if at_start {
        ShapeOp::StartsWith(boxed)
    } else {
        ShapeOp::EndsWith(boxed)
    };
    Some(shape)
}
fn strip_anchors(parts: &[Hir]) -> Option<(bool, &[Hir], bool)> {
let (mut lo, mut hi) = (0_usize, parts.len());
let start = parts
.first()
.is_some_and(|p| matches!(p.kind(), HirKind::Look(Look::Start)));
if start {
lo += 1;
}
let end = parts
.last()
.is_some_and(|p| matches!(p.kind(), HirKind::Look(Look::End)));
if end {
hi -= 1;
}
if hi <= lo {
return None;
}
Some((start, &parts[lo..hi], end))
}
/// Concatenates the bytes of a sequence made only of literals and
/// literal-only capture groups.
///
/// Returns `None` when `parts` is empty, when any element is neither a literal
/// nor a capture, or when a capture is rejected by `capture_literal_bytes`.
fn collect_literal_bytes(parts: &[Hir]) -> Option<Vec<u8>> {
    if parts.is_empty() {
        return None;
    }
    parts.iter().try_fold(Vec::new(), |mut acc, part| {
        match part.kind() {
            HirKind::Literal(lit) => acc.extend_from_slice(&lit.0),
            HirKind::Capture(cap) => {
                let inner = capture_literal_bytes(cap)?;
                acc.extend_from_slice(&inner);
            }
            _ => return None,
        }
        Some(acc)
    })
}
fn capture_literal_bytes(c: ®ex_syntax::hir::Capture) -> Option<Vec<u8>> {
match c.sub.kind() {
HirKind::Literal(Literal(bytes)) => Some(bytes.to_vec()),
HirKind::Concat(parts) => collect_literal_bytes(parts),
_ => None,
}
}
/// Number of bytes `cls` expands to, or `None` when `collect_class_bytes`
/// cannot expand it.
fn class_byte_count(cls: &Class) -> Option<usize> {
    collect_class_bytes(cls).map(|members| members.len())
}
/// Expands a character class into the individual bytes it matches.
///
/// Returns `None` when the class has more than `CLASS_BYTE_CAP` members or,
/// for a Unicode class, when any range reaches beyond `0x7F`. Members are
/// listed in the order the class yields its ranges.
fn collect_class_bytes(cls: &Class) -> Option<Vec<u8>> {
    // Collecting one element past the cap is enough to detect an oversized
    // class without walking huge ranges to the end.
    let members: Vec<u8> = match cls {
        Class::Bytes(set) => set
            .iter()
            .flat_map(|range| range.start()..=range.end())
            .take(CLASS_BYTE_CAP + 1)
            .collect(),
        Class::Unicode(set) => {
            if set.iter().any(|range| (range.end() as u32) > 0x7F) {
                return None;
            }
            set.iter()
                .flat_map(|range| (range.start() as u32)..=(range.end() as u32))
                .take(CLASS_BYTE_CAP + 1)
                .map(|cp| u8::try_from(cp).expect("ASCII range checked above"))
                .collect()
        }
    };
    if members.len() > CLASS_BYTE_CAP {
        None
    } else {
        Some(members)
    }
}
/// Recognises an optionally anchored alternation whose branches are all plain
/// literals and compiles it into an Aho-Corasick automaton.
///
/// Each branch must be a literal, a literal-only concatenation, or a
/// literal-only capture, and must be non-empty. The branch count must be
/// between 2 and `LITERAL_SET_CAP`. On success returns the automaton
/// (leftmost-first, ASCII case folding per `case_insensitive`), the anchor
/// found around the alternation, and the literals in branch order.
fn try_literal_alternation(
    hir: &Hir,
    case_insensitive: bool,
) -> Option<(AhoCorasick, Anchor, Vec<Vec<u8>>)> {
    let (outer_anchor, inner) = strip_outer_concat_anchors(hir)?;
    let branches = match inner.kind() {
        HirKind::Alternation(branches) => branches,
        _ => return None,
    };

    let literals = branches
        .iter()
        .map(|branch| {
            let bytes = match branch.kind() {
                HirKind::Literal(lit) => lit.0.to_vec(),
                HirKind::Concat(parts) => collect_literal_bytes(parts)?,
                HirKind::Capture(cap) => capture_literal_bytes(cap)?,
                _ => return None,
            };
            if bytes.is_empty() { None } else { Some(bytes) }
        })
        .collect::<Option<Vec<Vec<u8>>>>()?;

    // An alternation with no branches falls out here as well.
    if !(2..=LITERAL_SET_CAP).contains(&literals.len()) {
        return None;
    }

    let ac = AhoCorasick::builder()
        .match_kind(MatchKind::LeftmostFirst)
        .ascii_case_insensitive(case_insensitive)
        .build(&literals)
        .ok()?;
    Some((ac, outer_anchor, literals))
}
/// Separates outer `^` / `$` anchors from the single expression they wrap.
///
/// A non-concatenation is returned unchanged with `Anchor::Anywhere`. For a
/// concatenation, the anchors are removed with `strip_anchors` and exactly one
/// element must remain; otherwise the result is `None`.
fn strip_outer_concat_anchors(hir: &Hir) -> Option<(Anchor, &Hir)> {
    let parts = match hir.kind() {
        HirKind::Concat(parts) => parts,
        _ => return Some((Anchor::Anywhere, hir)),
    };
    let (start, body, end) = strip_anchors(parts)?;
    let [only] = body else {
        return None;
    };
    let anchor = if start && end {
        Anchor::Exact
    } else if start {
        Anchor::AtStart
    } else if end {
        Anchor::AtEnd
    } else {
        Anchor::Anywhere
    };
    Some((anchor, only))
}
/// Compiles an optionally anchored single literal into an ASCII
/// case-insensitive Aho-Corasick automaton.
///
/// The expression left after `strip_outer_concat_anchors` must be a literal, a
/// literal-only concatenation, or a literal-only capture, and must be
/// non-empty. Returns the automaton, the anchor, and the one literal.
fn try_single_literal_for_case_insensitive(
    hir: &Hir,
) -> Option<(AhoCorasick, Anchor, Vec<Vec<u8>>)> {
    let (anchor, node) = strip_outer_concat_anchors(hir)?;
    let needle = match node.kind() {
        HirKind::Literal(lit) => lit.0.to_vec(),
        HirKind::Concat(parts) => collect_literal_bytes(parts)?,
        HirKind::Capture(cap) => capture_literal_bytes(cap)?,
        _ => return None,
    };
    if needle.is_empty() {
        return None;
    }
    AhoCorasick::builder()
        .match_kind(MatchKind::LeftmostFirst)
        .ascii_case_insensitive(true)
        .build([&needle])
        .ok()
        .map(|ac| (ac, anchor, vec![needle]))
}
/// Last-resort literal reduction using `regex_syntax`'s prefix extractor.
///
/// Succeeds only when the extractor returns a finite, fully exact sequence of
/// between 2 and `LITERAL_SET_CAP` literals, meaning the pattern is equivalent
/// to "contains any of these literals". The result is always
/// `Anchor::Anywhere`.
///
/// Patterns containing any look-around assertion are refused. The extractor
/// treats every assertion as an empty match, so for `^(foo|bar)x` it reports
/// the exact set `{foox, barx}`; searching for those anywhere in the haystack
/// would ignore the `^`. Such patterns are left for the caller to handle.
///
/// Returns the automaton (leftmost-first, ASCII case folding per
/// `case_insensitive`), `Anchor::Anywhere`, and the extracted literals.
fn try_extractor_fallback(
    hir: &Hir,
    case_insensitive: bool,
) -> Option<(AhoCorasick, Anchor, Vec<Vec<u8>>)> {
    // Exactness from the extractor does not account for anchors or other
    // assertions, so it is only trustworthy when the pattern has none.
    if !hir.properties().look_set().is_empty() {
        return None;
    }

    let mut ex = Extractor::new();
    ex.kind(ExtractKind::Prefix);
    ex.limit_class(CLASS_BYTE_CAP);
    ex.limit_total(LITERAL_SET_CAP);
    let seq = ex.extract(hir);
    if !seq.is_exact() {
        return None;
    }
    let literals = seq.literals()?;
    if literals.len() < 2 || literals.len() > LITERAL_SET_CAP {
        return None;
    }

    let owned: Vec<Vec<u8>> = literals
        .iter()
        .map(|lit| lit.as_bytes().to_vec())
        .collect();
    let ac = AhoCorasick::builder()
        .match_kind(MatchKind::LeftmostFirst)
        .ascii_case_insensitive(case_insensitive)
        .build(&owned)
        .ok()?;
    Some((ac, Anchor::Anywhere, owned))
}
fn meta_plan(pattern: &str, reason: &'static str, hint: &'static str) -> Classified {
use regex_automata::meta;
let meta = meta::Regex::new(pattern)
.expect("meta::Regex::new should not fail after regex_syntax::parse succeeded");
Classified {
plan: Plan::Meta(Box::new(meta)),
descriptor: Descriptor {
tier: MatcherTier::Regex,
reason,
hint,
},
}
}
fn shape_reason(shape: &ShapeOp) -> &'static str {
match shape {
ShapeOp::ContainsByte(_) => "shape:contains-byte",
ShapeOp::ByteSet(_) => "shape:byte-set",
ShapeOp::Contains(_) => "shape:contains-literal",
ShapeOp::StartsWith(_) => "shape:starts-with",
ShapeOp::EndsWith(_) => "shape:ends-with",
ShapeOp::ExactMatch(_) => "shape:exact-match",
}
}
/// Maps a shape to its matcher tier.
///
/// Byte-level shapes, and anchored shapes whose literal is exactly one byte
/// long, are `MatcherTier::Byte`; everything else is `MatcherTier::Literal`.
fn tier_for_shape(shape: &ShapeOp) -> MatcherTier {
    let single_byte = match shape {
        ShapeOp::ContainsByte(_) | ShapeOp::ByteSet(_) => true,
        ShapeOp::Contains(_) => false,
        ShapeOp::StartsWith(lit) | ShapeOp::EndsWith(lit) | ShapeOp::ExactMatch(lit) => {
            lit.len() == 1
        }
    };
    if single_byte {
        MatcherTier::Byte
    } else {
        MatcherTier::Literal
    }
}
/// Converts owned literals into a boxed slice of boxed byte slices, keeping
/// their order.
fn literals_to_boxed(lits: Vec<Vec<u8>>) -> Box<[Box<[u8]>]> {
    let mut boxed: Vec<Box<[u8]>> = Vec::with_capacity(lits.len());
    for literal in lits {
        boxed.push(literal.into_boxed_slice());
    }
    boxed.into_boxed_slice()
}