use std::{borrow::Cow, collections::HashSet};
use regex::{Regex, RegexSet, escape};
use rustc_hash::FxHashSet;
use serde::{Deserialize, Serialize};
use crate::{
matcher::{MatchResultTrait, TextMatcherTrait},
process::process_matcher::{
ProcessType, ProcessTypeBitNode, ProcessedTextMasks, build_process_type_tree,
reduce_text_process_with_tree,
},
};
/// How the `word_list` of a [`RegexTable`] is turned into regex patterns.
///
/// Serialized in `snake_case`: `"similar_char"`, `"acrostic"`, `"regex"`.
#[derive(Serialize, Deserialize, Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum RegexMatchType {
    /// One pattern per table. Each word is a comma-separated list of literal
    /// alternatives for one position. Consecutive positions may be separated by at
    /// most one character other than a newline (`.?`).
    SimilarChar,
    /// One case-insensitive pattern per word. Each comma-separated piece of the word
    /// must appear, in order, at the start of the text or right after
    /// whitespace/punctuation.
    Acrostic,
    /// One pattern per word, used verbatim as a regular expression.
    Regex,
}
/// One rule table for `RegexMatcher`.
///
/// `regex_match_type` decides how `word_list` is compiled into patterns.
/// `table_id` and `match_id` are copied into every `RegexResult` this table produces.
#[derive(Debug, Clone)]
pub struct RegexTable<'a> {
    /// Copied into results. Together with the word index it identifies a rule, so each
    /// rule is reported at most once per `process` call.
    pub table_id: u32,
    /// Copied into results.
    pub match_id: u32,
    /// Process type the patterns apply to. A pattern hit only counts for processed-text
    /// variants whose mask has the bit for this process type set.
    pub process_type: ProcessType,
    /// How `word_list` is turned into regex patterns.
    pub regex_match_type: RegexMatchType,
    /// Pattern sources. Their interpretation depends on `regex_match_type`.
    pub word_list: Vec<&'a str>,
}
/// Metadata for one compiled pattern.
///
/// Entries are index-aligned with the patterns of the matcher's `RegexSet`, so a set
/// pattern id indexes the conf list directly.
#[derive(Debug, Clone)]
struct RegexConf {
    table_id: u32,
    match_id: u32,
    process_type: ProcessType,
    // Index of the source word in `RegexTable::word_list`; always 0 for `SimilarChar`
    // tables, which compile to a single pattern.
    word_id: u32,
    // Reported word: the generated pattern for `SimilarChar`, the original word otherwise.
    word: String,
}
/// A rule that matched the input text.
#[derive(Debug, Clone)]
pub struct RegexResult<'a> {
    /// `match_id` of the originating `RegexTable`.
    pub match_id: u32,
    /// `table_id` of the originating `RegexTable`.
    pub table_id: u32,
    /// The matched word. For `SimilarChar` tables this is the generated regex pattern,
    /// not a user-supplied word.
    pub word: Cow<'a, str>,
}
/// Exposes [`RegexResult`] through the crate-wide match-result interface.
///
/// `RegexResult` does not store the index of the matched word, so `word_id` is always
/// `0`. Regex hits carry no similarity score, so `similarity` is always `None`.
impl MatchResultTrait<'_> for RegexResult<'_> {
    fn table_id(&self) -> u32 {
        self.table_id
    }

    fn match_id(&self) -> u32 {
        self.match_id
    }

    fn word(&self) -> &str {
        self.word.as_ref()
    }

    fn word_id(&self) -> u32 {
        // The word index is not tracked on `RegexResult`.
        0
    }

    fn similarity(&self) -> Option<f64> {
        Option::None
    }
}
/// Matches text against regex rules built from `RegexTable`s.
///
/// All accepted patterns are compiled together into one `RegexSet`. Input text is run
/// through `reduce_text_process_with_tree`, and every resulting text variant is checked
/// against the set.
#[derive(Debug, Clone)]
pub struct RegexMatcher {
    // Built from the distinct process types of all tables. Passed to
    // `reduce_text_process_with_tree` to obtain the text variants to match.
    process_type_tree: Box<[ProcessTypeBitNode]>,
    // All accepted patterns, compiled as one set.
    regex_set: RegexSet,
    // Per-pattern metadata, index-aligned with `regex_set` (a set pattern id indexes
    // this slice). Entries are not deduplicated here despite the name; deduplication
    // happens when results are collected.
    regex_dedup_conf_list: Box<[RegexConf]>,
}
impl RegexMatcher {
    /// Upper bound, in bytes, on a single pattern's source text. Longer patterns are
    /// reported and skipped to bound regex compilation cost.
    const MAX_PATTERN_LEN: usize = 1024;

    /// Builds a matcher from a list of regex tables.
    ///
    /// Each table contributes patterns according to its [`RegexMatchType`]:
    ///
    /// * `SimilarChar`: one pattern for the whole table (see
    ///   [`similar_char_pattern`](Self::similar_char_pattern)).
    /// * `Acrostic`: one pattern per word (see
    ///   [`acrostic_pattern`](Self::acrostic_pattern)).
    /// * `Regex`: one pattern per word, used verbatim.
    ///
    /// Some patterns are skipped with a message on stderr instead of panicking:
    /// - patterns that fail to compile;
    /// - patterns longer than `MAX_PATTERN_LEN` bytes;
    /// - empty `SimilarChar` word lists and empty `Acrostic` words, which would
    ///   otherwise match every text.
    ///
    /// If the combined [`RegexSet`] fails to compile, the matcher falls back to an empty
    /// set that matches nothing.
    pub fn new(regex_table_list: &[RegexTable]) -> RegexMatcher {
        let mut process_type_set = HashSet::with_capacity(regex_table_list.len());
        let mut patterns: Vec<String> = Vec::new();
        let mut confs: Vec<RegexConf> = Vec::new();

        for table in regex_table_list {
            process_type_set.insert(table.process_type.bits());

            // Record an accepted pattern together with its metadata. Both are pushed at
            // the same index, so `RegexSet` pattern ids can index `confs` directly.
            let mut accept = |pattern: String, word_id: u32, word: String| {
                patterns.push(pattern);
                confs.push(RegexConf {
                    table_id: table.table_id,
                    match_id: table.match_id,
                    process_type: table.process_type,
                    word_id,
                    word,
                });
            };

            match table.regex_match_type {
                RegexMatchType::SimilarChar => {
                    // An empty list would yield the empty pattern, which matches any text.
                    if table.word_list.is_empty() {
                        eprintln!(
                            "SimilarChar word list of table {} is empty, ignored.",
                            table.table_id
                        );
                        continue;
                    }
                    let pattern = Self::similar_char_pattern(&table.word_list);
                    if pattern.len() > Self::MAX_PATTERN_LEN {
                        eprintln!(
                            "SimilarChar pattern is too long ({}), potential ReDoS risk. Skipping.",
                            pattern.len()
                        );
                        continue;
                    }
                    if Regex::new(&pattern).is_ok() {
                        // The generated pattern doubles as the reported word.
                        accept(pattern.clone(), 0, pattern);
                    } else {
                        eprintln!("SimilarChar pattern {pattern} is illegal, ignored.");
                    }
                }
                RegexMatchType::Acrostic => {
                    for (index, &word) in table.word_list.iter().enumerate() {
                        // An empty word yields a pattern that matches the start of any text.
                        if word.is_empty() {
                            eprintln!("Acrostic word is empty, ignored.");
                            continue;
                        }
                        let pattern = Self::acrostic_pattern(word);
                        if pattern.len() > Self::MAX_PATTERN_LEN {
                            eprintln!("Acrostic pattern too long for word {}, skipping.", word);
                            continue;
                        }
                        if Regex::new(&pattern).is_ok() {
                            accept(pattern, index as u32, word.to_owned());
                        } else {
                            eprintln!("Acrostic word {word} is illegal, ignored.");
                        }
                    }
                }
                RegexMatchType::Regex => {
                    for (index, &word) in table.word_list.iter().enumerate() {
                        if word.len() > Self::MAX_PATTERN_LEN {
                            eprintln!("Regex pattern too long, skipping: {:.20}...", word);
                            continue;
                        }
                        if Regex::new(word).is_ok() {
                            accept(word.to_owned(), index as u32, word.to_owned());
                        } else {
                            eprintln!("Regex word {word} is illegal, ignored.");
                        }
                    }
                }
            }
        }

        let process_type_tree = build_process_type_tree(&process_type_set).into_boxed_slice();
        // Patterns were validated individually, but the combined set can still fail
        // (e.g. size limits). Fall back to a set that never matches rather than panic.
        // `confs` is then unused, because an empty set reports no pattern ids.
        let regex_set = RegexSet::new(patterns.iter()).unwrap_or_else(|e| {
            eprintln!("Failed to compile regex set: {}", e);
            RegexSet::empty()
        });

        RegexMatcher {
            process_type_tree,
            regex_set,
            regex_dedup_conf_list: confs.into_boxed_slice(),
        }
    }

    /// Builds the `SimilarChar` pattern for a whole word list.
    ///
    /// Each entry becomes a non-capturing group of its regex-escaped, comma-separated
    /// alternatives. Groups are joined with `.?`, so at most one character other than a
    /// newline may sit between consecutive positions.
    fn similar_char_pattern(word_list: &[&str]) -> String {
        word_list
            .iter()
            .map(|charstr| format!("(?:{})", escape(charstr).replace(',', "|")))
            .collect::<Vec<String>>()
            .join(".?")
    }

    /// Builds the case-insensitive `Acrostic` pattern for one word.
    ///
    /// The first comma-separated piece must start the text or follow whitespace or
    /// punctuation (`[\s\pP]`). Each later piece must follow whitespace or punctuation
    /// somewhere after the previous piece.
    fn acrostic_pattern(word: &str) -> String {
        format!(
            r"(?i)(?:^|[\s\pP]+?){}",
            escape(word).replace(',', r".*?[\s\pP]+?")
        )
    }
}
impl<'a> TextMatcherTrait<'a, RegexResult<'a>> for RegexMatcher {
fn is_match(&'a self, text: &'a str) -> bool {
if text.is_empty() {
return false;
}
let processed_text_process_type_masks =
reduce_text_process_with_tree(&self.process_type_tree, text);
self.is_match_preprocessed(&processed_text_process_type_masks)
}
fn process(&'a self, text: &'a str) -> Vec<RegexResult<'a>> {
if text.is_empty() {
return Vec::new();
}
let processed_text_process_type_masks =
reduce_text_process_with_tree(&self.process_type_tree, text);
self.process_preprocessed(&processed_text_process_type_masks)
}
fn is_match_preprocessed(
&'a self,
processed_text_process_type_masks: &ProcessedTextMasks<'a>,
) -> bool {
for (processed_text, process_type_mask) in processed_text_process_type_masks {
for pattern_id in self.regex_set.matches(processed_text) {
let conf = &self.regex_dedup_conf_list[pattern_id];
if (process_type_mask & (1u64 << conf.process_type.bits())) != 0 {
return true;
}
}
}
false
}
fn process_preprocessed(
&'a self,
processed_text_process_type_masks: &ProcessedTextMasks<'a>,
) -> Vec<RegexResult<'a>> {
let mut result_list = Vec::new();
let mut table_id_index_set = FxHashSet::default();
for (processed_text, process_type_mask) in processed_text_process_type_masks {
for pattern_id in self.regex_set.matches(processed_text).iter() {
let conf = &self.regex_dedup_conf_list[pattern_id];
if (process_type_mask & (1u64 << conf.process_type.bits())) == 0 {
continue;
}
let table_id_index = ((conf.table_id as usize) << 32) | (conf.word_id as usize);
if table_id_index_set.insert(table_id_index) {
result_list.push(RegexResult {
match_id: conf.match_id,
table_id: conf.table_id,
word: Cow::Owned(conf.word.clone()),
});
}
}
}
result_list
}
}