use memchr::memmem;
use regex::Regex;
use std::sync::RwLock;
use crate::common::StringMap;
/// Number of entries at which a `PatternCache` evicts half of its contents before inserting.
const MAX_CACHE_SIZE: usize = 10_000;
/// Returns `true` if `needle` occurs anywhere in `haystack`, comparing ASCII letters
/// case-insensitively.
///
/// Non-ASCII bytes must match exactly. An empty `needle` matches every haystack, and a
/// `needle` longer than `haystack` never matches.
#[inline]
fn contains_case_insensitive(haystack: &str, needle: &str) -> bool {
    let needle = needle.as_bytes();
    // `windows(0)` panics, and an empty needle trivially occurs everywhere.
    if needle.is_empty() {
        return true;
    }
    // `windows` yields nothing when the needle is longer than the haystack, so that case
    // falls out as `false` without a separate check.
    haystack
        .as_bytes()
        .windows(needle.len())
        .any(|window| window.eq_ignore_ascii_case(needle))
}
/// A `LIKE` pattern pre-analysed into the cheapest matching strategy.
///
/// Simple wildcard shapes get dedicated string-comparison variants so matching can skip the
/// regex engine; anything else is compiled to an anchored [`Regex`].
#[derive(Debug, Clone)]
pub enum CompiledPattern {
    /// The whole text must equal this string.
    Exact(String),
    /// The text must start with this string (shape `abc%`).
    Prefix(String),
    /// The text must end with this string (shape `%abc`).
    Suffix(String),
    /// The text must contain this string (shape `%abc%`).
    Contains(String),
    /// The text must start with the first string and end with the second, with the two
    /// parts not overlapping (shape `abc%xyz`).
    PrefixSuffix(String, String),
    /// General pattern translated to a regular expression anchored with `^` and `$`.
    Regex(Regex),
    /// The pattern `%`: accepts any text, including the empty string.
    MatchAll,
    /// The pattern `_` on its own: a lone single-character wildcard.
    SingleChar,
}
impl CompiledPattern {
#[inline]
pub fn matches(&self, text: &str) -> bool {
match self {
CompiledPattern::MatchAll => true,
CompiledPattern::SingleChar => text.len() == 1,
CompiledPattern::Exact(s) => text == s,
CompiledPattern::Prefix(p) => text.starts_with(p),
CompiledPattern::Suffix(s) => text.ends_with(s),
CompiledPattern::Contains(c) => memmem::find(text.as_bytes(), c.as_bytes()).is_some(),
CompiledPattern::PrefixSuffix(p, s) => {
text.starts_with(p) && text.ends_with(s) && text.len() >= p.len() + s.len()
}
CompiledPattern::Regex(re) => re.is_match(text),
}
}
#[inline]
pub fn matches_insensitive(&self, text: &str) -> bool {
match self {
CompiledPattern::MatchAll => true,
CompiledPattern::SingleChar => text.len() == 1,
CompiledPattern::Exact(s) => text.eq_ignore_ascii_case(s),
CompiledPattern::Prefix(p) => {
text.len() >= p.len() && text[..p.len()].eq_ignore_ascii_case(p)
}
CompiledPattern::Suffix(s) => {
text.len() >= s.len() && text[text.len() - s.len()..].eq_ignore_ascii_case(s)
}
CompiledPattern::Contains(c) => contains_case_insensitive(text, c),
CompiledPattern::PrefixSuffix(p, s) => {
if text.len() < p.len() + s.len() {
return false;
}
text[..p.len()].eq_ignore_ascii_case(p)
&& text[text.len() - s.len()..].eq_ignore_ascii_case(s)
}
CompiledPattern::Regex(re) => {
re.is_match(text)
}
}
}
}
/// One cached pattern: its case-sensitive compilation plus, once requested, the
/// case-insensitive one.
struct CacheEntry {
    /// Result of `compile_pattern(key, false)`.
    pattern: CompiledPattern,
    /// Result of `compile_pattern(key, true)`; `None` until the first case-insensitive lookup.
    case_insensitive_pattern: Option<CompiledPattern>,
}
/// Cache of compiled `LIKE` patterns keyed by the raw pattern text, guarded by an `RwLock`.
///
/// Holds at most `MAX_CACHE_SIZE` entries. When that size is reached, half of the entries
/// are evicted, in whatever order the underlying map yields its keys.
pub struct PatternCache {
    cache: RwLock<StringMap<CacheEntry>>,
}
impl PatternCache {
pub fn new() -> Self {
Self {
cache: RwLock::new(StringMap::new()),
}
}
pub fn get_or_compile(&self, pattern: &str) -> CompiledPattern {
if let Ok(cache) = self.cache.read() {
if let Some(entry) = cache.get(pattern) {
return entry.pattern.clone();
}
}
let compiled = compile_pattern(pattern, false);
if let Ok(mut cache) = self.cache.write() {
if cache.len() >= MAX_CACHE_SIZE {
let keys: Vec<_> = cache.keys().take(MAX_CACHE_SIZE / 2).cloned().collect();
for key in keys {
cache.remove(&key);
}
}
cache.insert(
pattern.to_string(),
CacheEntry {
pattern: compiled.clone(),
case_insensitive_pattern: None,
},
);
}
compiled
}
pub fn get_or_compile_insensitive(&self, pattern: &str) -> CompiledPattern {
if let Ok(cache) = self.cache.read() {
if let Some(entry) = cache.get(pattern) {
if let Some(ref ci_pattern) = entry.case_insensitive_pattern {
return ci_pattern.clone();
}
}
}
let compiled = compile_pattern(pattern, true);
if let Ok(mut cache) = self.cache.write() {
if let Some(entry) = cache.get_mut(pattern) {
entry.case_insensitive_pattern = Some(compiled.clone());
} else {
if cache.len() >= MAX_CACHE_SIZE {
let keys: Vec<_> = cache.keys().take(MAX_CACHE_SIZE / 2).cloned().collect();
for key in keys {
cache.remove(&key);
}
}
cache.insert(
pattern.to_string(),
CacheEntry {
pattern: compile_pattern(pattern, false),
case_insensitive_pattern: Some(compiled.clone()),
},
);
}
}
compiled
}
pub fn clear(&self) {
if let Ok(mut cache) = self.cache.write() {
cache.clear();
}
}
pub fn size(&self) -> usize {
self.cache.read().map(|c| c.len()).unwrap_or(0)
}
}
impl Default for PatternCache {
fn default() -> Self {
Self::new()
}
}
/// Backing storage for [`global_pattern_cache`], initialised on first access.
static GLOBAL_CACHE: std::sync::OnceLock<PatternCache> = std::sync::OnceLock::new();

/// Returns the process-wide [`PatternCache`], creating an empty one on first call.
pub fn global_pattern_cache() -> &'static PatternCache {
    GLOBAL_CACHE.get_or_init(PatternCache::default)
}
/// Compiles a SQL `LIKE` pattern into the cheapest [`CompiledPattern`] that implements it.
///
/// `%` matches any run of characters and `_` matches a single character. `\` introduces an
/// escape, which only [`like_to_regex`] understands, so any pattern containing `\` always
/// takes the regex path. Escape-free patterns shaped like `abc`, `abc%`, `%abc`, `%abc%` or
/// `abc%xyz` get dedicated string-comparison variants.
///
/// `case_insensitive` only affects the regex path, where it adds an `(?i)` flag. The
/// string-based variants are identical either way and rely on the caller using
/// [`CompiledPattern::matches_insensitive`].
fn compile_pattern(pattern: &str, case_insensitive: bool) -> CompiledPattern {
    match pattern {
        "" => return CompiledPattern::Exact(String::new()),
        "%" => return CompiledPattern::MatchAll,
        "_" => return CompiledPattern::SingleChar,
        _ => {}
    }
    // With an escape present, a `%` or `_` may be literal; the fast paths would misread it
    // as a wildcard (e.g. `100\%` as a prefix match on `100\`).
    if !pattern.contains('\\') {
        if let Some(fast) = compile_fast_path(pattern) {
            return fast;
        }
    }
    let regex_pattern = like_to_regex(pattern, case_insensitive);
    match Regex::new(&regex_pattern) {
        Ok(re) => CompiledPattern::Regex(re),
        // Best-effort fallback (e.g. the regex exceeds the engine's size limit): compare the
        // pattern text literally instead of failing.
        Err(_) => CompiledPattern::Exact(pattern.to_string()),
    }
}

/// Recognises escape-free patterns that need no regex, returning `None` for everything else.
///
/// Any `_` forces the regex path. A pattern without `%` is an exact match. Otherwise the
/// shape of the `%`-separated pieces selects prefix, suffix, contains or prefix+suffix
/// matching.
fn compile_fast_path(pattern: &str) -> Option<CompiledPattern> {
    if pattern.contains('_') {
        return None;
    }
    if !pattern.contains('%') {
        return Some(CompiledPattern::Exact(pattern.to_string()));
    }
    let parts: Vec<&str> = pattern.split('%').collect();
    match parts.as_slice() {
        ["", suffix] if !suffix.is_empty() => Some(CompiledPattern::Suffix(suffix.to_string())),
        [prefix, ""] if !prefix.is_empty() => Some(CompiledPattern::Prefix(prefix.to_string())),
        ["", contains, ""] if !contains.is_empty() => {
            Some(CompiledPattern::Contains(contains.to_string()))
        }
        [prefix, suffix] if !prefix.is_empty() && !suffix.is_empty() => Some(
            CompiledPattern::PrefixSuffix(prefix.to_string(), suffix.to_string()),
        ),
        _ => None,
    }
}
/// Translates a SQL `LIKE` pattern into a regular-expression source anchored with `^` and `$`.
///
/// * `%` becomes `.*` and `_` becomes `.`. The `(?s)` flag lets both match line breaks,
///   just as the string-based [`CompiledPattern`] variants do.
/// * `\` escapes the next character, which is then matched literally. So `\%` is a literal
///   percent sign and `\d` is just `d`, not a digit class. A trailing `\` is a literal
///   backslash.
/// * Every other character is matched literally, backslash-escaped if it is a regex
///   metacharacter.
///
/// With `case_insensitive`, the result starts with `(?i)`.
fn like_to_regex(pattern: &str, case_insensitive: bool) -> String {
    // Capacity hint: each input char yields at most two output bytes per input byte, plus the
    // flag groups and anchors.
    let mut regex = String::with_capacity(pattern.len() * 2 + 12);
    if case_insensitive {
        regex.push_str("(?i)");
    }
    regex.push_str("(?s)^");
    let mut chars = pattern.chars();
    while let Some(c) = chars.next() {
        match c {
            '%' => regex.push_str(".*"),
            '_' => regex.push('.'),
            '\\' => push_regex_literal(&mut regex, chars.next().unwrap_or('\\')),
            _ => push_regex_literal(&mut regex, c),
        }
    }
    regex.push('$');
    regex
}

/// Appends `c` to `regex` so that it matches only itself, escaping regex metacharacters.
fn push_regex_literal(regex: &mut String, c: char) {
    if matches!(
        c,
        '\\' | '.' | '^' | '$' | '*' | '+' | '?' | '{' | '}' | '[' | ']' | '(' | ')' | '|'
    ) {
        regex.push('\\');
    }
    regex.push(c);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `pattern`, compiled case-sensitively, accepts every text in `hits` and
    /// rejects every text in `misses`.
    fn assert_like(pattern: &str, hits: &[&str], misses: &[&str]) {
        let compiled = compile_pattern(pattern, false);
        for text in hits {
            assert!(compiled.matches(text), "{pattern:?} should match {text:?}");
        }
        for text in misses {
            assert!(!compiled.matches(text), "{pattern:?} should not match {text:?}");
        }
    }

    #[test]
    fn test_exact_match() {
        assert_like("hello", &["hello"], &["Hello", "hello world"]);
    }

    #[test]
    fn test_prefix_match() {
        assert_like("hello%", &["hello", "hello world"], &["say hello"]);
    }

    #[test]
    fn test_suffix_match() {
        assert_like("%world", &["world", "hello world"], &["world hello"]);
    }

    #[test]
    fn test_contains_match() {
        assert_like("%ell%", &["hello", "yell", "well done"], &["hallo"]);
    }

    #[test]
    fn test_prefix_suffix_match() {
        assert_like(
            "hello%world",
            &["helloworld", "hello big world"],
            &["hello", "world"],
        );
    }

    #[test]
    fn test_match_all() {
        assert_like("%", &["", "anything"], &[]);
    }

    #[test]
    fn test_single_char() {
        assert_like("_", &["a", "Z"], &["", "ab"]);
    }

    #[test]
    fn test_complex_pattern() {
        assert_like("h_llo%", &["hello", "hallo world"], &["hllo"]);
    }

    #[test]
    fn test_case_insensitive() {
        let compiled = compile_pattern("hello%", false);
        for text in ["Hello World", "HELLO"] {
            assert!(compiled.matches_insensitive(text));
        }
        assert!(!compiled.matches_insensitive("say hello"));
    }

    #[test]
    fn test_global_cache() {
        let cache = global_pattern_cache();
        // The first lookup compiles and stores the pattern; the second should find it cached.
        for _ in 0..2 {
            assert!(cache.get_or_compile("test%").matches("testing"));
        }
        assert!(cache.size() >= 1);
    }

    #[test]
    fn test_cache_insensitive() {
        let compiled = global_pattern_cache().get_or_compile_insensitive("%Test%");
        for text in ["testing", "TEST"] {
            assert!(compiled.matches_insensitive(text));
        }
    }
}