use bstr::BString;
use regex_syntax::hir::{Hir, literal};
use std::collections::BTreeSet;
use std::convert::Infallible;
use std::mem;
/// A value that can report its matching constraints to a [`MatcherVisitor`].
///
/// Implementations are expected to call the visitor's `visit_*` methods
/// (for example `visit_match_starts_with`) so the visitor can accumulate the
/// literal prefixes implied by the matcher.
pub trait Matcher {
/// Reports this matcher's constraints by invoking methods on `visitor`.
fn visit(&self, visitor: &mut MatcherVisitor);
}
/// Shared references to a matcher are matchers too: the call is forwarded
/// unchanged to the referenced value.
impl<M: Matcher> Matcher for &M {
    fn visit(&self, visitor: &mut MatcherVisitor) {
        // `self` is `&&M`; peel one layer so the underlying impl is selected
        // explicitly rather than through deref coercion.
        <M as Matcher>::visit(*self, visitor);
    }
}
/// Prefix state for one nesting level of a [`MatcherVisitor`].
#[derive(Debug)]
struct Frame {
// Prefixes for the conjunction currently being built. Narrowed through
// `intersect_prefix_expansions`; `None` means nothing has been recorded yet
// (the first recorded set is adopted as-is).
and_literal_prefixes: Option<BTreeSet<BString>>,
// Prefixes gathered from completed alternatives. Grown through
// `union_prefixes_limited`; `None` is sticky there, i.e. once the set is
// unknown (or over the size cap) it stays unknown.
or_literal_prefixes: Option<BTreeSet<BString>>,
}
impl Default for Frame {
    /// Builds an empty frame: no conjunction recorded yet (`None`) and an
    /// empty, but known, set of alternatives (`Some(∅)`).
    fn default() -> Self {
        let and_literal_prefixes = None;
        let or_literal_prefixes = Some(BTreeSet::new());
        Self {
            and_literal_prefixes,
            or_literal_prefixes,
        }
    }
}
impl Frame {
    /// Consumes the frame and returns the prefix set it guarantees, if any.
    ///
    /// The pending conjunction is first merged into the alternatives (capped at
    /// 100 entries). The result is `None` when the merged set is unknown, when
    /// it is empty, or when its smallest element is the empty string (an empty
    /// prefix constrains nothing).
    fn finish(self) -> Option<BTreeSet<BString>> {
        let Self {
            and_literal_prefixes: pending_and,
            or_literal_prefixes: mut alternatives,
        } = self;
        union_prefixes_limited(&mut alternatives, pending_and, 100);

        // `BTreeSet` is ordered, so an empty string, if present, is always the
        // first element.
        alternatives.filter(|set| {
            set.first()
                .is_some_and(|smallest| !smallest.is_empty())
        })
    }
}
/// Accumulates the literal prefixes implied by the matchers that visit it.
///
/// Nesting is tracked with a stack of frames: `visit_nested_start` pushes a
/// frame and `visit_nested_finish` pops it and folds it into its parent.
#[derive(Debug)]
pub struct MatcherVisitor {
// Stack of open frames; the first entry is the root frame created by `new`.
frames: Vec<Frame>,
}
impl MatcherVisitor {
    /// Creates a visitor with a single, empty root frame.
    pub(crate) fn new() -> Self {
        Self {
            frames: vec![Frame::default()],
        }
    }

    /// Returns the innermost open frame.
    ///
    /// # Panics
    ///
    /// Panics if the frame stack is empty.
    fn current_frame(&mut self) -> &mut Frame {
        self.frames
            .last_mut()
            .expect("mismatched nesting calls to MatcherVisitor")
    }

    /// Finishes the visit and returns the collected prefixes, if any.
    ///
    /// The root frame is replaced by a fresh one, so the visitor can be reused.
    ///
    /// # Panics
    ///
    /// Panics if nested frames are still open. The visitor is reset to a single
    /// fresh root frame before panicking.
    pub(crate) fn finish(&mut self) -> Option<BTreeSet<BString>> {
        let Self { frames } = self;
        let frame = match &mut frames[..] {
            [only_frame] => mem::take(only_frame),
            _ => {
                // Restore a usable state before reporting the misuse.
                frames.clear();
                frames.push(Frame::default());
                panic!("mismatched nesting calls to MatcherVisitor")
            }
        };
        frame.finish()
    }

    /// Opens a nested group; must be paired with [`Self::visit_nested_finish`].
    pub fn visit_nested_start(&mut self) {
        self.frames.push(Frame::default());
    }

    /// Closes the innermost nested group and intersects its prefixes into the
    /// enclosing frame's current conjunction.
    ///
    /// # Panics
    ///
    /// Panics if there is no open nested group. The frame stack is left
    /// untouched in that case, so the root frame is never popped.
    pub fn visit_nested_finish(&mut self) {
        // Check before popping: removing the root frame would leave the stack
        // empty and make every later call fail with an unrelated panic.
        assert!(
            self.frames.len() > 1,
            "mismatched nesting calls to MatcherVisitor"
        );
        let frame = self
            .frames
            .pop()
            .expect("every finish should match with a start");
        let new_inner = frame.finish();
        intersect_prefix_expansions(&mut self.current_frame().and_literal_prefixes, new_inner);
    }

    /// Marks the boundary between two alternatives: the conjunction built so
    /// far is moved into the frame's set of alternatives (capped at 100
    /// entries) and a new conjunction is started.
    pub fn visit_or_in(&mut self) {
        let frame = self.current_frame();
        let new_and = frame.and_literal_prefixes.take();
        union_prefixes_limited(&mut frame.or_literal_prefixes, new_and, 100);
    }

    /// Records a regular-expression constraint.
    ///
    /// Patterns that fail to parse are deliberately ignored (they contribute no
    /// prefix information), as are patterns from which no prefixes can be
    /// extracted.
    pub fn visit_match_regex(&mut self, regex: &str) {
        let Ok(hir) = regex_syntax::parse(regex) else {
            return;
        };
        let new_prefixes = extract_prefixes(&hir);
        intersect_prefix_expansions(
            &mut self.current_frame().and_literal_prefixes,
            new_prefixes,
        );
    }

    /// Records an exact-match constraint. For prefix purposes this is the same
    /// as a `starts_with` constraint on the whole string.
    pub fn visit_match_equals(&mut self, equals: &str) {
        self.visit_match_starts_with(equals);
    }

    /// Records a constraint that the value starts with `prefix`.
    pub fn visit_match_starts_with(&mut self, prefix: &str) {
        let new_prefixes = Some(BTreeSet::from([BString::from(prefix)]));
        intersect_prefix_expansions(
            &mut self.current_frame().and_literal_prefixes,
            new_prefixes,
        );
    }
}
/// Merges `rhs` into `lhs`, giving up once the union exceeds `max_len` entries.
///
/// `None` stands for "unknown / unbounded" and is absorbing: if either side is
/// `None`, or the union would hold more than `max_len` elements, `lhs` becomes
/// `None`. Otherwise `lhs` holds the union of both sets.
fn union_prefixes_limited(
    lhs: &mut Option<BTreeSet<BString>>,
    rhs: Option<BTreeSet<BString>>,
    max_len: usize,
) {
    let Some(lhs_inner) = lhs else {
        return;
    };
    let Some(mut rhs_inner) = rhs else {
        *lhs = None;
        return;
    };
    // Always insert the smaller set into the larger one.
    if rhs_inner.len() > lhs_inner.len() {
        mem::swap(lhs_inner, &mut rhs_inner);
    }
    let mut len = lhs_inner.len();
    // The larger side may already be over the cap on its own; the loop below
    // only notices the cap when an insertion happens, so check up front.
    if len > max_len {
        *lhs = None;
        return;
    }
    for v in rhs_inner {
        let did_insert = lhs_inner.insert(v);
        len += usize::from(did_insert);
        if len > max_len {
            *lhs = None;
            return;
        }
    }
}
/// Merge-walks two ordered prefix sets and collects into `dst` every element
/// that has a prefix (possibly itself) in the other set.
///
/// Example: `{a, box, z}` and `{ankle, apple, bo, dog}` yield
/// `{ankle, apple, box}`.
///
/// Elements are popped from `lhs` and `rhs` as the walk advances, so both
/// inputs are partially or fully drained. The function always returns `None`;
/// the `Option<Infallible>` return type only exists so that `?` can end the
/// walk as soon as either side runs out.
fn intersect_prefix_expansions_into(
    dst: &mut BTreeSet<BString>,
    lhs: &mut BTreeSet<BString>,
    rhs: &mut BTreeSet<BString>,
) -> Option<Infallible> {
    let mut left = lhs.pop_first()?;
    let mut right = rhs.pop_first()?;
    loop {
        if left <= right {
            // `left` sorts first, so it can only be a prefix of `right`.
            if right.starts_with(&left) {
                // Keep the longer string and stay on `left`: it may cover
                // further `rhs` entries.
                dst.insert(right);
                right = rhs.pop_first()?;
            } else {
                left = lhs.pop_first()?;
            }
        } else if left.starts_with(&right) {
            // Mirror case: `right` is a strict prefix of `left`.
            dst.insert(left);
            left = lhs.pop_first()?;
        } else {
            right = rhs.pop_first()?;
        }
    }
}
/// Narrows `lhs` by the prefix set `rhs`.
///
/// `None` means "no information": a `None` `rhs` leaves `lhs` untouched, and a
/// `None` `lhs` simply adopts `rhs`. When both sides are known, `lhs` is
/// replaced by the result of `intersect_prefix_expansions_into`.
fn intersect_prefix_expansions(
    lhs: &mut Option<BTreeSet<BString>>,
    rhs: Option<BTreeSet<BString>>,
) {
    // Nothing to narrow by.
    let Some(mut incoming) = rhs else {
        return;
    };
    match lhs {
        None => *lhs = Some(incoming),
        Some(existing) => {
            let mut narrowed = BTreeSet::new();
            // The return value carries no information (it is always `None`).
            _ = intersect_prefix_expansions_into(&mut narrowed, existing, &mut incoming);
            *existing = narrowed;
        }
    }
}
/// Extracts the literal prefixes of a parsed regex.
///
/// Returns `None` when the pattern's leading look-around set contains no
/// haystack anchor (an unanchored pattern says nothing about how the input
/// starts), or when the extractor yields no finite literal set.
fn extract_prefixes(hir: &Hir) -> Option<BTreeSet<BString>> {
    let anchored = hir
        .properties()
        .look_set_prefix()
        .contains_anchor_haystack();
    if !anchored {
        return None;
    }

    let extracted = literal::Extractor::new().extract(hir);
    let literals = extracted.literals()?;
    let prefixes: BTreeSet<BString> = literals
        .iter()
        .map(|literal| BString::from(literal.as_bytes()))
        .collect();
    Some(prefixes)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal matcher contributing a single `starts_with` constraint.
    struct SimpleMatcher(&'static str);

    impl Matcher for SimpleMatcher {
        fn visit(&self, visitor: &mut MatcherVisitor) {
            let Self(prefix) = *self;
            visitor.visit_match_starts_with(prefix);
        }
    }

    /// Builds an ordered set of byte strings from string literals.
    fn make_set(items: &[&str]) -> BTreeSet<BString> {
        items.iter().map(|&item| BString::from(item)).collect()
    }

    /// Runs the low-level intersection on two literal lists and returns `dst`.
    fn run_intersect(lhs: &[&str], rhs: &[&str]) -> BTreeSet<BString> {
        let (mut left, mut right) = (make_set(lhs), make_set(rhs));
        let mut collected = BTreeSet::new();
        _ = intersect_prefix_expansions_into(&mut collected, &mut left, &mut right);
        collected
    }

    /// Parses `pattern` (which must be valid) and extracts its prefixes.
    fn prefixes_of(pattern: &str) -> Option<BTreeSet<BString>> {
        let hir = regex_syntax::parse(pattern).unwrap();
        extract_prefixes(&hir)
    }

    // ---- visitor behaviour -------------------------------------------------

    #[test]
    fn test_matcher_by_reference() {
        let matcher = SimpleMatcher("/api");
        let mut visitor = MatcherVisitor::new();
        #[expect(clippy::needless_borrow)]
        (&matcher).visit(&mut visitor);
        assert!(visitor.finish().is_some());
    }

    #[test]
    fn test_visit_match_regex_anchored() {
        let mut visitor = MatcherVisitor::new();
        visitor.visit_match_regex(r"^/api/.*");
        assert_eq!(visitor.finish().unwrap(), make_set(&["/api/"]));
    }

    #[test]
    fn test_visit_match_regex_unanchored() {
        let mut visitor = MatcherVisitor::new();
        visitor.visit_match_regex(r"/api/.*");
        let collected = visitor.finish();
        assert!(collected.is_none());
    }

    #[test]
    fn test_visit_match_equals() {
        let mut visitor = MatcherVisitor::new();
        visitor.visit_match_equals("/api/users");
        assert_eq!(visitor.finish().unwrap(), make_set(&["/api/users"]));
    }

    #[test]
    fn test_visit_nested() {
        let mut visitor = MatcherVisitor::new();
        visitor.visit_nested_start();
        visitor.visit_match_starts_with("/api");
        visitor.visit_nested_finish();
        assert_eq!(visitor.finish().unwrap(), make_set(&["/api"]));
    }

    #[test]
    fn test_visit_or_in() {
        let mut visitor = MatcherVisitor::new();
        visitor.visit_match_starts_with("/v1");
        visitor.visit_or_in();
        visitor.visit_match_starts_with("/v2");
        assert_eq!(visitor.finish().unwrap(), make_set(&["/v1", "/v2"]));
    }

    #[test]
    fn test_nested_with_or() {
        let mut visitor = MatcherVisitor::new();
        visitor.visit_match_starts_with("/v");
        visitor.visit_nested_start();
        visitor.visit_match_starts_with("/v1");
        visitor.visit_or_in();
        visitor.visit_match_starts_with("/v2");
        visitor.visit_nested_finish();
        assert_eq!(visitor.finish().unwrap(), make_set(&["/v1", "/v2"]));
    }

    #[test]
    fn test_visit_match_regex_bare_anchor_yields_no_prefixes() {
        let mut visitor = MatcherVisitor::new();
        visitor.visit_match_regex("^");
        let collected = visitor.finish();
        assert!(collected.is_none());
    }

    #[test]
    fn test_visit_match_regex_invalid_is_ignored() {
        let mut visitor = MatcherVisitor::new();
        visitor.visit_match_starts_with("/api");
        visitor.visit_match_regex("[invalid");
        assert_eq!(visitor.finish().unwrap(), make_set(&["/api"]));
    }

    #[test]
    #[should_panic = "mismatched nesting calls to MatcherVisitor"]
    fn test_unbalanced_nesting_extra_start() {
        let mut visitor = MatcherVisitor::new();
        visitor.visit_nested_start();
        visitor.visit_match_starts_with("/api");
        visitor.finish();
    }

    #[test]
    #[should_panic = "mismatched nesting calls to MatcherVisitor"]
    fn test_unbalanced_nesting_extra_finish() {
        let mut visitor = MatcherVisitor::new();
        visitor.visit_match_starts_with("/api");
        visitor.visit_nested_finish();
    }

    // ---- union -------------------------------------------------------------

    #[test]
    fn test_union_prefixes_limited_exceeds_max() {
        let low: BTreeSet<BString> = (0..60u8).map(|byte| BString::new(vec![byte])).collect();
        let high: BTreeSet<BString> = (60..120u8).map(|byte| BString::new(vec![byte])).collect();
        let mut lhs = Some(low);
        union_prefixes_limited(&mut lhs, Some(high), 100);
        assert!(lhs.is_none());
    }

    // ---- Option-level intersection -------------------------------------------

    #[test]
    fn test_intersect_prefix_expansions_both_none() {
        let mut lhs: Option<BTreeSet<BString>> = None;
        intersect_prefix_expansions(&mut lhs, None);
        assert!(lhs.is_none());
    }

    #[test]
    fn test_intersect_prefix_expansions_lhs_none() {
        let incoming = make_set(&["/api"]);
        let mut lhs = None;
        intersect_prefix_expansions(&mut lhs, Some(incoming.clone()));
        assert_eq!(lhs, Some(incoming));
    }

    #[test]
    fn test_intersect_prefix_expansions_rhs_none() {
        let original = make_set(&["/api"]);
        let mut lhs = Some(original.clone());
        intersect_prefix_expansions(&mut lhs, None);
        assert_eq!(lhs, Some(original));
    }

    #[test]
    fn test_intersect_prefix_expansions_with_values() {
        let mut lhs = Some(make_set(&["/a"]));
        intersect_prefix_expansions(&mut lhs, Some(make_set(&["/api"])));
        assert_eq!(lhs.unwrap(), make_set(&["/api"]));
    }

    // ---- set-level intersection ----------------------------------------------

    #[test]
    fn test_intersect_prefix_expansions_into_doc_example() {
        let lhs = ["a", "box", "z"];
        let rhs = ["ankle", "apple", "bo", "dog"];
        assert_eq!(
            run_intersect(&lhs, &rhs),
            make_set(&["ankle", "apple", "box"])
        );
    }

    #[test]
    fn test_intersect_prefix_expansions_into_empty_lhs() {
        assert!(run_intersect(&[], &["abc"]).is_empty());
    }

    #[test]
    fn test_intersect_prefix_expansions_into_empty_rhs() {
        assert!(run_intersect(&["abc"], &[]).is_empty());
    }

    #[test]
    fn test_intersect_prefix_expansions_into_no_overlap() {
        assert!(run_intersect(&["abc"], &["xyz"]).is_empty());
    }

    #[test]
    fn test_intersect_prefix_expansions_into_exact_match() {
        assert_eq!(run_intersect(&["abc"], &["abc"]), make_set(&["abc"]));
    }

    #[test]
    fn test_intersect_prefix_expansions_into_lhs_prefix_of_rhs() {
        assert_eq!(run_intersect(&["ab"], &["abcd"]), make_set(&["abcd"]));
    }

    #[test]
    fn test_intersect_prefix_expansions_into_rhs_prefix_of_lhs() {
        assert_eq!(run_intersect(&["abcd"], &["ab"]), make_set(&["abcd"]));
    }

    #[test]
    fn test_intersect_prefix_expansions_into_one_to_many() {
        let expected = make_set(&["aa", "ab", "ac"]);
        assert_eq!(run_intersect(&["a"], &["aa", "ab", "ac"]), expected);
    }

    #[test]
    fn test_intersect_prefix_expansions_into_nested_prefixes_on_one_side() {
        let expected = make_set(&["aaab", "ba"]);
        assert_eq!(
            run_intersect(&["a", "aaaaa", "ba"], &["aaab", "ba"]),
            expected
        );
    }

    #[test]
    fn test_intersect_prefix_expansions_into_multiple_lhs_prefixes() {
        let expected = make_set(&["ab", "bcd"]);
        assert_eq!(run_intersect(&["ab", "b"], &["a", "bcd"]), expected);
    }

    // ---- regex prefix extraction -----------------------------------------------

    #[test]
    fn test_extract_prefixes_anchored() {
        let prefixes = prefixes_of(r"^/api/.*");
        assert!(prefixes.is_some());
        assert_eq!(prefixes.unwrap(), make_set(&["/api/"]));
    }

    #[test]
    fn test_extract_prefixes_unanchored() {
        assert!(prefixes_of(r"/api/.*").is_none());
    }

    #[test]
    fn test_extract_prefixes_alternation_with_literal_suffix() {
        assert_eq!(
            prefixes_of(r"^(a|b)123[^/]*").unwrap(),
            make_set(&["a123", "b123"])
        );
    }

    #[test]
    fn test_extract_prefixes_character_class_with_literal_suffix() {
        assert_eq!(
            prefixes_of(r"^[a-c]123.*").unwrap(),
            make_set(&["a123", "b123", "c123"])
        );
    }

    #[test]
    fn test_extract_prefixes_multiline_anchor() {
        assert!(prefixes_of(r"(?m)^foo").is_none());
    }

    #[test]
    fn test_extract_prefixes_explicit_haystack_anchor() {
        assert_eq!(prefixes_of(r"\Afoo").unwrap(), make_set(&["foo"]));
    }

    #[test]
    fn test_extract_prefixes_bare_anchor() {
        assert_eq!(prefixes_of(r"^"), Some(make_set(&[""])));
    }

    #[test]
    fn test_extract_prefixes_anchored_alternation() {
        assert_eq!(
            prefixes_of(r"^(foo|bar)").unwrap(),
            make_set(&["bar", "foo"])
        );
    }

    #[test]
    fn test_extract_prefixes_literal_after_wildcard_not_extracted() {
        assert_eq!(prefixes_of(r"^a.*abc123456").unwrap(), make_set(&["a"]));
    }

    #[test]
    fn test_extract_prefixes_dot_prefix_not_extractable() {
        assert!(prefixes_of(r"^.abc1234").is_none());
    }
}