use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A dictionary-backed part-of-speech tagger with suffix-based fallback rules.
///
/// Known words are looked up case-insensitively in an internal dictionary;
/// unknown words are tagged heuristically from their suffix or capitalization,
/// falling back to `default_tag`.
#[derive(Debug, Clone)]
pub struct POSTagger {
    /// Lowercased word -> Penn-Treebank-style tag.
    dictionary: HashMap<String, String>,
    /// Tag used when no dictionary entry or heuristic applies ("NN").
    default_tag: String,
}

impl POSTagger {
    /// Creates a tagger pre-populated with a small set of common English words.
    pub fn new() -> Self {
        let mut dictionary = HashMap::new();
        Self::add_common_words(&mut dictionary);
        Self {
            dictionary,
            default_tag: "NN".to_string(),
        }
    }

    /// Seeds `dict` with a handful of closed-class English words: pronouns,
    /// "be"/"have"/"do" forms, determiners, prepositions, conjunctions, and
    /// wh-question words.
    fn add_common_words(dict: &mut HashMap<String, String>) {
        for word in &["i", "you", "he", "she", "it", "we", "they"] {
            dict.insert(word.to_string(), "PRP".to_string());
        }
        for word in &["is", "are", "was", "were", "be", "been", "being"] {
            dict.insert(word.to_string(), "VB".to_string());
        }
        for word in &["have", "has", "had"] {
            dict.insert(word.to_string(), "VB".to_string());
        }
        for word in &["do", "does", "did"] {
            dict.insert(word.to_string(), "VB".to_string());
        }
        for word in &["the", "a", "an"] {
            dict.insert(word.to_string(), "DT".to_string());
        }
        for word in &["in", "on", "at", "by", "for", "with", "about", "from", "to"] {
            dict.insert(word.to_string(), "IN".to_string());
        }
        for word in &["and", "or", "but"] {
            dict.insert(word.to_string(), "CC".to_string());
        }
        for word in &["what", "who", "where", "when", "why", "how", "which"] {
            dict.insert(word.to_string(), "WH".to_string());
        }
    }

    /// Tags each token, returning `(original_token, tag)` pairs.
    ///
    /// Dictionary lookup is case-insensitive; out-of-dictionary tokens fall
    /// back to [`Self::infer_tag`].
    pub fn tag(&self, tokens: &[String]) -> Vec<(String, String)> {
        tokens
            .iter()
            .map(|token| {
                let lower = token.to_lowercase();
                let tag = self
                    .dictionary
                    .get(&lower)
                    .cloned()
                    // Pass the ORIGINAL token so capitalization is still
                    // visible to the NNP heuristic. Previously the lowercased
                    // form was passed, which made the NNP branch unreachable.
                    .unwrap_or_else(|| self.infer_tag(token));
                (token.clone(), tag)
            })
            .collect()
    }

    /// Heuristically tags an out-of-dictionary word.
    ///
    /// Suffix rules are checked on a lowercased copy; the capitalization
    /// check uses the original spelling.
    fn infer_tag(&self, word: &str) -> String {
        let lower = word.to_lowercase();
        if lower.ends_with("ing") {
            return "VBG".to_string();
        }
        if lower.ends_with("ed") {
            return "VBD".to_string();
        }
        if lower.ends_with("ly") {
            return "RB".to_string();
        }
        if lower.ends_with("tion") || lower.ends_with("ness") || lower.ends_with("ment") {
            return "NN".to_string();
        }
        if lower.ends_with('s') && lower.len() > 2 {
            return "NNS".to_string();
        }
        // BUG FIX: this branch used to receive an already-lowercased word and
        // could never fire; capitalized unknown words are now tagged NNP.
        if word.chars().next().is_some_and(|c| c.is_uppercase()) {
            return "NNP".to_string();
        }
        self.default_tag.clone()
    }

    /// Adds or overrides a dictionary entry (key is stored lowercased).
    pub fn add_word(&mut self, word: String, tag: String) {
        self.dictionary.insert(word.to_lowercase(), tag);
    }
}

impl Default for POSTagger {
    fn default() -> Self {
        Self::new()
    }
}
/// Coarse named-entity categories produced by the recognizers in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum EntityTag {
    /// A person's name (heuristic: capitalized out-of-gazetteer token).
    Person,
    /// A company or institution (gazetteer hit, or "Inc"/"Corp"/"Ltd" suffix).
    Organization,
    /// A geographic location (gazetteer hit).
    Location,
    /// A date-like token (month name, '/'- or '-'-containing, or 4-digit year).
    Date,
    /// A numeric token (digits with optional ',' and '.').
    Number,
    /// Catch-all category; not currently produced by `classify_entity`.
    Miscellaneous,
}

impl EntityTag {
    /// Returns the conventional short uppercase label for this tag
    /// (e.g. "PERSON", "ORG", "LOC").
    pub fn as_str(&self) -> &'static str {
        match self {
            EntityTag::Person => "PERSON",
            EntityTag::Organization => "ORG",
            EntityTag::Location => "LOC",
            EntityTag::Date => "DATE",
            EntityTag::Number => "NUM",
            EntityTag::Miscellaneous => "MISC",
        }
    }
}
/// A single entity occurrence found in a piece of text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NamedEntity {
    /// The matched token exactly as it appeared in the input.
    pub text: String,
    /// Entity category assigned by the recognizer.
    pub tag: EntityTag,
    /// Byte offset of the first byte of `text` within the input string.
    pub start: usize,
    /// Byte offset one past the last byte of `text` within the input string.
    pub end: usize,
    /// Recognizer confidence; the rule-based recognizer always emits 0.8.
    pub confidence: f64,
}
/// A gazetteer- and heuristic-based named-entity recognizer.
///
/// Tokens are first looked up (lowercased) in a small built-in gazetteer of
/// locations and organization words; unmatched tokens are classified by
/// surface shape: digit strings -> `Number`, capitalized words -> `Person`
/// or `Organization`, month names / date-like shapes -> `Date`.
#[derive(Debug, Clone)]
pub struct NamedEntityRecognizer {
    /// Lowercased surface form -> entity tag (the gazetteer).
    entities: HashMap<String, EntityTag>,
    /// Companion POS tagger. NOTE(review): not consulted by `recognize`;
    /// kept for construction/API compatibility.
    pos_tagger: POSTagger,
}

impl NamedEntityRecognizer {
    /// Creates a recognizer seeded with a small location/organization gazetteer.
    pub fn new() -> Self {
        let mut entities = HashMap::new();
        for loc in &[
            "america", "europe", "asia", "africa", "usa", "uk", "china", "japan",
            "germany", "france", "canada", "australia", "london", "paris", "tokyo",
            "newyork", "berlin",
        ] {
            entities.insert(loc.to_string(), EntityTag::Location);
        }
        for org in &[
            "university", "company", "corporation", "institute", "foundation",
            "association",
        ] {
            entities.insert(org.to_string(), EntityTag::Organization);
        }
        Self {
            entities,
            pos_tagger: POSTagger::new(),
        }
    }

    /// Scans the whitespace-separated tokens of `text` and returns every one
    /// that classifies as an entity, with its byte offsets into `text` and a
    /// fixed confidence of 0.8.
    pub fn recognize(&self, text: &str) -> Vec<NamedEntity> {
        let mut result = Vec::new();
        let tokens: Vec<&str> = text.split_whitespace().collect();
        // Cursor tracking how far into `text` we have matched, so that a
        // repeated token resolves to its next occurrence, not the first.
        let mut pos = 0;
        for token in tokens {
            let start = text[pos..].find(token).map(|p| pos + p).unwrap_or(pos);
            let end = start + token.len();
            if let Some(entity) = self.classify_entity(token) {
                result.push(NamedEntity {
                    text: token.to_string(),
                    tag: entity,
                    start,
                    end,
                    confidence: 0.8,
                });
            }
            pos = end;
        }
        result
    }

    /// Classifies a single token, or returns `None` when no rule applies.
    fn classify_entity(&self, token: &str) -> Option<EntityTag> {
        let lower = token.to_lowercase();
        if let Some(tag) = self.entities.get(&lower) {
            return Some(*tag);
        }
        // BUG FIX: require at least one digit. Previously the `all(..)` check
        // alone let punctuation-only tokens such as "," or "..." be
        // misclassified as Number.
        if token.chars().any(|c| c.is_numeric())
            && token
                .chars()
                .all(|c| c.is_numeric() || c == ',' || c == '.')
        {
            return Some(EntityTag::Number);
        }
        if token.chars().next().is_some_and(|c| c.is_uppercase()) {
            // Corporate suffixes win over the default Person guess.
            if token.ends_with("Inc") || token.ends_with("Corp") || token.ends_with("Ltd") {
                return Some(EntityTag::Organization);
            }
            return Some(EntityTag::Person);
        }
        if self.is_date_like(token) {
            return Some(EntityTag::Date);
        }
        None
    }

    /// Heuristic date detector: month names (full or abbreviated), tokens
    /// containing '/' or '-', or 4-character all-digit strings (years).
    /// NOTE(review): the '-' rule also matches ordinary hyphenated words;
    /// since this runs after the capitalization check, the false-positive
    /// surface is lowercase hyphenated tokens — consider tightening.
    fn is_date_like(&self, token: &str) -> bool {
        let months = [
            "january", "february", "march", "april", "may", "june", "july",
            "august", "september", "october", "november", "december", "jan",
            "feb", "mar", "apr", "may", "jun", "jul", "aug", "sep", "oct",
            "nov", "dec",
        ];
        let lower = token.to_lowercase();
        months.contains(&lower.as_str())
            || token.contains('/')
            || token.contains('-')
            || (token.len() == 4 && token.chars().all(|c| c.is_numeric()))
    }

    /// Adds or overrides a gazetteer entry (key is stored lowercased).
    pub fn add_entity(&mut self, text: String, tag: EntityTag) {
        self.entities.insert(text.to_lowercase(), tag);
    }
}

impl Default for NamedEntityRecognizer {
    fn default() -> Self {
        Self::new()
    }
}
/// Simplified, string-typed entity record used by [`RuleBasedNER`].
#[derive(Debug, Clone)]
pub struct Entity {
    /// Matched text exactly as it appeared in the input.
    pub text: String,
    /// Byte offset where the match starts in the input.
    pub start: usize,
    /// Byte offset one past the end of the match.
    pub end: usize,
    /// Entity label string (e.g. "PERSON", "LOC"), from [`EntityTag::as_str`].
    pub entity_type: String,
}
/// Thin adapter that exposes [`NamedEntityRecognizer`] results as plain
/// string-typed [`Entity`] values.
#[derive(Debug, Clone)]
pub struct RuleBasedNER {
    inner: NamedEntityRecognizer,
}

impl RuleBasedNER {
    /// Builds an adapter around a freshly constructed recognizer.
    pub fn new() -> Self {
        let inner = NamedEntityRecognizer::new();
        Self { inner }
    }

    /// Runs recognition over `text` and converts each hit into an [`Entity`].
    ///
    /// Always returns `Ok`; the `Result` exists for interface compatibility.
    pub fn extract_entities(&self, text: &str) -> Result<Vec<Entity>, String> {
        let mut converted = Vec::new();
        for ne in self.inner.recognize(text) {
            converted.push(Entity {
                text: ne.text,
                start: ne.start,
                end: ne.end,
                entity_type: ne.tag.as_str().to_string(),
            });
        }
        Ok(converted)
    }
}

impl Default for RuleBasedNER {
    fn default() -> Self {
        Self::new()
    }
}
/// Splits raw text into string tokens.
pub trait Tokenizer {
    /// Tokenizes `text`, or returns an error message on failure.
    fn tokenize(&self, text: &str) -> Result<Vec<String>, String>;
}

/// Whitespace-based tokenizer: tokens are maximal runs of non-whitespace.
#[derive(Debug, Clone, Default)]
pub struct WordTokenizer;

impl WordTokenizer {
    /// Creates a tokenizer (stateless; equivalent to `WordTokenizer::default()`).
    pub fn new() -> Self {
        Self
    }
}

impl Tokenizer for WordTokenizer {
    /// Splits on Unicode whitespace; this implementation never fails.
    fn tokenize(&self, text: &str) -> Result<Vec<String>, String> {
        let tokens = text.split_whitespace().map(String::from).collect();
        Ok(tokens)
    }
}
/// Result of a lexicon-based sentiment pass.
#[derive(Debug, Clone)]
pub struct SentimentScore {
    /// Signed polarity, normalized by the TOTAL number of words in the text.
    pub score: f32,
    /// Absolute value of `score`.
    pub magnitude: f32,
}

/// Sentiment analyzer driven by small positive/negative word lists.
#[derive(Debug, Clone)]
pub struct LexiconSentimentAnalyzer {
    /// Lowercased word -> positive weight (+1.0 in the basic lexicon).
    positive_words: HashMap<String, f32>,
    /// Lowercased word -> negative weight (-1.0 in the basic lexicon).
    negative_words: HashMap<String, f32>,
}

impl LexiconSentimentAnalyzer {
    /// Builds an analyzer with a tiny built-in lexicon (10 positive and 10
    /// negative English words, weighted +1.0 / -1.0 respectively).
    pub fn with_basiclexicon() -> Self {
        const POSITIVE: [&str; 10] = [
            "good", "great", "excellent", "amazing", "wonderful", "love", "best",
            "perfect", "happy", "nice",
        ];
        const NEGATIVE: [&str; 10] = [
            "bad", "terrible", "awful", "hate", "worst", "horrible", "poor",
            "sad", "wrong", "fail",
        ];
        let positive_words = POSITIVE.iter().map(|w| (w.to_string(), 1.0)).collect();
        let negative_words = NEGATIVE.iter().map(|w| (w.to_string(), -1.0)).collect();
        Self {
            positive_words,
            negative_words,
        }
    }

    /// Scores `text` by summing the lexicon weights of its lowercased,
    /// whitespace-split words, then normalizing by the total word count.
    ///
    /// A text with no lexicon hits scores exactly 0.0 (this also covers the
    /// empty string, so there is no division by zero). Never fails; the
    /// `Result` exists for interface compatibility.
    pub fn analyze(&self, text: &str) -> Result<SentimentScore, String> {
        let lowercase = text.to_lowercase();
        let words: Vec<&str> = lowercase.split_whitespace().collect();

        let mut total_score = 0.0f32;
        let mut hits = 0usize;
        for word in &words {
            // Positive lexicon takes priority, matching the original lookup order.
            let weight = self
                .positive_words
                .get(*word)
                .or_else(|| self.negative_words.get(*word));
            if let Some(&w) = weight {
                total_score += w;
                hits += 1;
            }
        }

        // Normalize by text length (not hit count), preserving the original
        // contract; hit-free texts are exactly neutral.
        let score = if hits == 0 {
            0.0
        } else {
            total_score / words.len() as f32
        };

        Ok(SentimentScore {
            score,
            magnitude: score.abs(),
        })
    }
}

impl Default for LexiconSentimentAnalyzer {
    fn default() -> Self {
        Self::with_basiclexicon()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tagging yields one (token, tag) pair per input token; "The" resolves
    /// through the built-in dictionary to the determiner tag.
    #[test]
    fn test_pos_tagger() {
        let tagger = POSTagger::new();
        let tokens = vec!["The".to_string(), "cat".to_string(), "runs".to_string()];
        let tagged = tagger.tag(&tokens);
        assert_eq!(tagged.len(), 3);
        assert_eq!(tagged[0].1, "DT");
    }

    /// A capitalized name ("John") and a gazetteer location ("London")
    /// should both be recognized in the same sentence.
    #[test]
    fn test_ner() {
        let ner = NamedEntityRecognizer::new();
        let text = "John visited London in 2023";
        let entities = ner.recognize(text);
        assert!(!entities.is_empty());
        assert!(entities.iter().any(|e| e.tag == EntityTag::Person));
        assert!(entities.iter().any(|e| e.tag == EntityTag::Location));
    }

    /// `as_str` returns the short uppercase label for each tag variant.
    #[test]
    fn test_entity_tag_string() {
        assert_eq!(EntityTag::Person.as_str(), "PERSON");
        assert_eq!(EntityTag::Location.as_str(), "LOC");
        assert_eq!(EntityTag::Organization.as_str(), "ORG");
    }

    /// A decimal token ("123.45") is classified as `EntityTag::Number`.
    #[test]
    fn test_number_recognition() {
        let ner = NamedEntityRecognizer::new();
        let text = "The price is 123.45 dollars";
        let entities = ner.recognize(text);
        let numbers: Vec<_> = entities
            .iter()
            .filter(|e| e.tag == EntityTag::Number)
            .collect();
        assert!(!numbers.is_empty());
    }
}