use regex::Regex;
use crate::error::{Result, TextError};
/// A single token of input text, optionally annotated with a part-of-speech
/// tag and a lemma.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Token {
    /// The token's surface text.
    pub text: String,
    /// Part-of-speech tag (e.g. `"NN"`), if one has been assigned.
    pub pos: Option<String>,
    /// Lemma, if one has been assigned.
    pub lemma: Option<String>,
}

impl Token {
    /// Creates a token with the given text and no PoS tag or lemma.
    pub fn new(text: impl Into<String>) -> Token {
        Token {
            text: text.into(),
            pos: None,
            lemma: None,
        }
    }

    /// Returns this token with its part-of-speech tag set to `pos`,
    /// replacing any previously assigned tag.
    #[must_use]
    pub fn with_pos(mut self, pos: impl Into<String>) -> Token {
        self.pos = Some(pos.into());
        self
    }

    /// Returns this token with its lemma set to `lemma`,
    /// replacing any previously assigned lemma.
    #[must_use]
    pub fn with_lemma(mut self, lemma: impl Into<String>) -> Token {
        self.lemma = Some(lemma.into());
        self
    }
}
/// One element of a [`Pattern`] template.
///
/// Every variant except [`PatternElement::Gap`] consumes exactly one token and
/// records that token's text in the resulting match's groups.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PatternElement {
    /// Matches a token whose text equals this string after lowercasing both sides.
    Literal(String),
    /// Matches a token whose part-of-speech tag equals this string exactly.
    /// A token without a tag is compared as if its tag were `""`.
    PoS(String),
    /// Matches a token whose text matches this regular expression anywhere
    /// (the regex is not implicitly anchored).
    Regex(String),
    /// Matches any single token.
    Any,
    /// Skips between `min` and `max` tokens (inclusive) without capturing them.
    /// Shorter skips are tried first.
    Gap {
        /// Minimum number of tokens to skip.
        min: usize,
        /// Maximum number of tokens to skip.
        max: usize,
    },
}

/// A sequence of [`PatternElement`]s matched left to right against
/// consecutive tokens.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Pattern {
    /// The elements to match, in order.
    pub template: Vec<PatternElement>,
}

impl Pattern {
    /// Creates a pattern from its element template.
    pub fn new(template: Vec<PatternElement>) -> Pattern {
        Pattern { template }
    }
}
/// A successful match of a named pattern over a token slice.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Match {
    /// Name under which the matching pattern was registered.
    pub pattern_name: String,
    /// Index of the first token covered by the match.
    pub start: usize,
    /// Index one past the last token covered by the match (exclusive end).
    pub end: usize,
    /// Text of each token consumed by a non-gap element, in pattern order.
    pub groups: Vec<String>,
}
#[derive(Default)]
pub struct PatternMatcher {
patterns: Vec<(String, Pattern)>,
regex_cache: std::collections::HashMap<String, Regex>,
}
impl PatternMatcher {
pub fn new() -> PatternMatcher {
PatternMatcher::default()
}
pub fn add_pattern(&mut self, name: impl Into<String>, pattern: Pattern) -> Result<()> {
for elem in &pattern.template {
if let PatternElement::Regex(re_str) = elem {
if !self.regex_cache.contains_key(re_str) {
let compiled = Regex::new(re_str).map_err(|e| {
TextError::InvalidInput(format!("Bad regex '{}': {}", re_str, e))
})?;
self.regex_cache.insert(re_str.clone(), compiled);
}
}
}
self.patterns.push((name.into(), pattern));
Ok(())
}
pub fn match_all(&self, tokens: &[Token]) -> Vec<Match> {
let mut results = Vec::new();
for (name, pattern) in &self.patterns {
for start in 0..tokens.len() {
if let Some((end, groups)) = self.try_match(pattern, tokens, start) {
results.push(Match {
pattern_name: name.clone(),
start,
end,
groups,
});
}
}
}
results
}
fn try_match(
&self,
pattern: &Pattern,
tokens: &[Token],
start: usize,
) -> Option<(usize, Vec<String>)> {
self.try_match_from(pattern, tokens, start, 0, Vec::new())
}
fn try_match_from(
&self,
pattern: &Pattern,
tokens: &[Token],
pos: usize,
elem_idx: usize,
groups: Vec<String>,
) -> Option<(usize, Vec<String>)> {
if elem_idx >= pattern.template.len() {
return Some((pos, groups));
}
let elem = &pattern.template[elem_idx];
match elem {
PatternElement::Literal(s) => {
if pos >= tokens.len() {
return None;
}
if tokens[pos].text.to_lowercase() != s.to_lowercase() {
return None;
}
let mut new_groups = groups;
new_groups.push(tokens[pos].text.clone());
self.try_match_from(pattern, tokens, pos + 1, elem_idx + 1, new_groups)
}
PatternElement::PoS(tag) => {
if pos >= tokens.len() {
return None;
}
let tok_pos = tokens[pos].pos.as_deref().unwrap_or("");
if tok_pos != tag.as_str() {
return None;
}
let mut new_groups = groups;
new_groups.push(tokens[pos].text.clone());
self.try_match_from(pattern, tokens, pos + 1, elem_idx + 1, new_groups)
}
PatternElement::Regex(re_str) => {
if pos >= tokens.len() {
return None;
}
let re = self.regex_cache.get(re_str)?;
if !re.is_match(&tokens[pos].text) {
return None;
}
let mut new_groups = groups;
new_groups.push(tokens[pos].text.clone());
self.try_match_from(pattern, tokens, pos + 1, elem_idx + 1, new_groups)
}
PatternElement::Any => {
if pos >= tokens.len() {
return None;
}
let mut new_groups = groups;
new_groups.push(tokens[pos].text.clone());
self.try_match_from(pattern, tokens, pos + 1, elem_idx + 1, new_groups)
}
PatternElement::Gap { min, max } => {
for skip in *min..=*max {
let new_pos = pos + skip;
if new_pos > tokens.len() {
break;
}
if let Some(result) =
self.try_match_from(pattern, tokens, new_pos, elem_idx + 1, groups.clone())
{
return Some(result);
}
}
None
}
}
}
}
/// Builds a [`PatternMatcher`] preloaded with single-token, regex-based entity
/// patterns named `DATE`, `MONEY`, `EMAIL`, `URL` and `PHONE` (registered in
/// that order).
///
/// The regexes carry no `^`/`$` anchors.
///
/// # Errors
///
/// Propagates the error from [`PatternMatcher::add_pattern`] if any of the
/// built-in regexes fails to compile.
pub fn build_ner_pattern_matcher() -> Result<PatternMatcher> {
    // (entity label, regex applied to one token), in registration order.
    let entity_regexes: &[(&str, &str)] = &[
        (
            "DATE",
            r"(?:(?:0?[1-9]|1[0-2])[\/\-](?:0?[1-9]|[12][0-9]|3[01])[\/\-](?:19|20)?\d{2}|(?:19|20)\d{2}[\/\-](?:0?[1-9]|1[0-2])[\/\-](?:0?[1-9]|[12][0-9]|3[01]))",
        ),
        ("MONEY", r"\$[0-9]+(?:\.[0-9]+)?"),
        (
            "EMAIL",
            r"[A-Za-z0-9._%+\-]+@[A-Za-z0-9.\-]+\.[A-Za-z]{2,}",
        ),
        ("URL", r"https?://[^\s]+"),
        (
            "PHONE",
            r"(?:\+?1[\-.\s]?)?\(?\d{3}\)?[\-.\s]\d{3}[\-.\s]\d{4}",
        ),
    ];

    let mut ner = PatternMatcher::new();
    for &(label, regex_src) in entity_regexes {
        let single_token = Pattern::new(vec![PatternElement::Regex(regex_src.to_string())]);
        ner.add_pattern(label, single_token)?;
    }
    Ok(ner)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Splits `text` on whitespace into tokens without PoS tags or lemmas.
    fn tokenize_simple(text: &str) -> Vec<Token> {
        text.split_whitespace().map(Token::new).collect()
    }

    /// Shorthand for a [`PatternElement::Literal`].
    fn lit(text: &str) -> PatternElement {
        PatternElement::Literal(text.to_string())
    }

    /// Builds a matcher holding exactly one pattern, registered as `name`.
    fn matcher_with(name: &str, template: Vec<PatternElement>) -> PatternMatcher {
        let mut matcher = PatternMatcher::new();
        matcher
            .add_pattern(name, Pattern::new(template))
            .expect("add_pattern failed");
        matcher
    }

    /// `(start, end)` of each match, in result order.
    fn spans(matches: &[Match]) -> Vec<(usize, usize)> {
        matches.iter().map(|m| (m.start, m.end)).collect()
    }

    #[test]
    fn test_literal_match() {
        let matcher = matcher_with("greeting", vec![lit("hello"), lit("world")]);
        let matches = matcher.match_all(&tokenize_simple("hello world foo bar"));
        assert_eq!(spans(&matches), vec![(0, 2)]);
        assert_eq!(matches[0].pattern_name, "greeting");
        assert_eq!(matches[0].groups, vec!["hello", "world"]);
    }

    #[test]
    fn test_literal_match_ignores_case_and_keeps_original_text() {
        let matcher = matcher_with("greeting", vec![lit("hello"), lit("world")]);
        let matches = matcher.match_all(&tokenize_simple("Hello WORLD"));
        assert_eq!(spans(&matches), vec![(0, 2)]);
        assert_eq!(matches[0].groups, vec!["Hello", "WORLD"]);
    }

    #[test]
    fn test_pos_match() {
        let matcher = matcher_with(
            "dt_nn",
            vec![
                PatternElement::PoS("DT".to_string()),
                PatternElement::PoS("NN".to_string()),
            ],
        );
        let tokens = vec![
            Token::new("the").with_pos("DT"),
            Token::new("dog").with_pos("NN"),
            Token::new("runs").with_pos("VBZ"),
        ];
        let matches = matcher.match_all(&tokens);
        assert_eq!(spans(&matches), vec![(0, 2)]);
        assert_eq!(matches[0].groups, vec!["the", "dog"]);
    }

    #[test]
    fn test_pos_does_not_match_untagged_tokens() {
        let matcher = matcher_with("nn", vec![PatternElement::PoS("NN".to_string())]);
        assert!(matcher.match_all(&tokenize_simple("dog")).is_empty());
    }

    #[test]
    fn test_regex_match() {
        let matcher = matcher_with(
            "money",
            vec![PatternElement::Regex(r"\$[0-9]+(?:\.[0-9]+)?".to_string())],
        );
        let matches = matcher.match_all(&tokenize_simple("costs $29.99 shipping $5"));
        assert_eq!(spans(&matches), vec![(1, 2), (3, 4)]);
        let texts: Vec<&str> = matches.iter().map(|m| m.groups[0].as_str()).collect();
        assert_eq!(texts, vec!["$29.99", "$5"]);
    }

    #[test]
    fn test_any_match() {
        let matcher = matcher_with("any_word", vec![lit("the"), PatternElement::Any]);
        let matches = matcher.match_all(&tokenize_simple("the cat sat on the mat"));
        assert_eq!(spans(&matches), vec![(0, 2), (4, 6)]);
        assert_eq!(matches[0].groups, vec!["the", "cat"]);
        assert_eq!(matches[1].groups, vec!["the", "mat"]);
    }

    #[test]
    fn test_gap_match() {
        let matcher = matcher_with(
            "verb_phrase",
            vec![lit("john"), PatternElement::Gap { min: 0, max: 2 }, lit("mary")],
        );
        let matches = matcher.match_all(&tokenize_simple("john loves mary"));
        assert_eq!(spans(&matches), vec![(0, 3)]);
        // Skipped tokens are not captured.
        assert_eq!(matches[0].groups, vec!["john", "mary"]);
    }

    #[test]
    fn test_gap_respects_max() {
        let matcher = matcher_with(
            "verb_phrase",
            vec![lit("john"), PatternElement::Gap { min: 0, max: 2 }, lit("mary")],
        );
        assert!(matcher.match_all(&tokenize_simple("john a b c mary")).is_empty());
    }

    #[test]
    fn test_gap_respects_min() {
        let matcher = matcher_with(
            "verb_phrase",
            vec![lit("john"), PatternElement::Gap { min: 1, max: 2 }, lit("mary")],
        );
        assert!(matcher.match_all(&tokenize_simple("john mary")).is_empty());
    }

    #[test]
    fn test_unbounded_gap_stops_at_end_of_input() {
        let matcher = matcher_with(
            "verb_phrase",
            vec![
                lit("john"),
                PatternElement::Gap { min: 0, max: usize::MAX },
                lit("mary"),
            ],
        );
        let matches = matcher.match_all(&tokenize_simple("john a b c mary"));
        assert_eq!(spans(&matches), vec![(0, 5)]);
        assert!(matcher.match_all(&tokenize_simple("john a b")).is_empty());
    }

    #[test]
    fn test_no_tokens_no_matches() {
        let matcher = matcher_with("any", vec![PatternElement::Any]);
        assert!(matcher.match_all(&[]).is_empty());
    }

    #[test]
    fn test_ner_patterns_email() {
        let matcher = build_ner_pattern_matcher().expect("build failed");
        let matches = matcher.match_all(&tokenize_simple("contact user@example.com for info"));
        let email = matches
            .iter()
            .find(|m| m.pattern_name == "EMAIL")
            .expect("no EMAIL match");
        assert_eq!((email.start, email.end), (1, 2));
        assert_eq!(email.groups, vec!["user@example.com"]);
    }

    #[test]
    fn test_ner_patterns_money() {
        let matcher = build_ner_pattern_matcher().expect("build failed");
        let matches = matcher.match_all(&tokenize_simple("costs $100 today"));
        let money = matches
            .iter()
            .find(|m| m.pattern_name == "MONEY")
            .expect("no MONEY match");
        assert_eq!((money.start, money.end), (1, 2));
        assert_eq!(money.groups, vec!["$100"]);
    }

    #[test]
    fn test_bad_regex_error() {
        let mut matcher = PatternMatcher::new();
        let result = matcher.add_pattern(
            "bad",
            Pattern::new(vec![PatternElement::Regex("[invalid".to_string())]),
        );
        assert!(result.is_err());
        // The rejected pattern must not have been registered.
        assert!(matcher.match_all(&tokenize_simple("[invalid anything")).is_empty());
    }
}