use crate::utils::nlp::NamedEntityRecognizer;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{debug, info};
/// Category assigned to an [`ExtractedEntity`].
///
/// `Hash` + `Eq` let it key the per-type pattern table and the grouped output
/// of `extract_with_relations`; `Copy` lets it be passed around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum EntityType {
    /// Capitalized word classified as a person by the heuristic extractor.
    Person,
    /// Capitalized word ending in an organization suffix (e.g. `Inc`).
    Organization,
    /// Location mention. NOTE(review): no extractor in this module emits it yet.
    Location,
    /// Date matched by one of the date regexes.
    DateTime,
    /// Numeric literal matched by the number regex.
    Number,
    /// `http`/`https` URL.
    URL,
    /// E-mail address.
    Email,
    /// `<iri>` reference or `prefix:local` name.
    RDFResource,
    /// RDF property. NOTE(review): not emitted by any extractor in this module.
    Property,
    /// RDF class. NOTE(review): not emitted by any extractor in this module.
    Class,
    /// Catch-all category. NOTE(review): not emitted by any extractor in this module.
    Other,
}
/// A single entity mention found in a piece of text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractedEntity {
    /// The matched text.
    pub text: String,
    /// Category of the mention.
    pub entity_type: EntityType,
    /// Byte offset of the first byte of the mention in the input text.
    pub start: usize,
    /// Byte offset one past the last byte of the mention (exclusive).
    pub end: usize,
    /// Heuristic score in `[0, 1]`; compared against `min_confidence`.
    pub confidence: f32,
    /// IRI the mention was linked to, if linking produced one.
    pub resolved_uri: Option<String>,
    /// Free-form extra attributes (left empty by `EntityExtractor`).
    pub metadata: HashMap<String, String>,
}
/// Switches and thresholds controlling [`EntityExtractor`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntityExtractionConfig {
    /// Emit [`EntityType::Person`] mentions.
    pub extract_persons: bool,
    /// Emit [`EntityType::Organization`] mentions.
    pub extract_organizations: bool,
    /// Emit [`EntityType::Location`] mentions.
    /// NOTE(review): no location extractor exists in this module yet.
    pub extract_locations: bool,
    /// Emit [`EntityType::DateTime`] mentions.
    pub extract_datetime: bool,
    /// Emit [`EntityType::RDFResource`] mentions.
    pub extract_rdf_resources: bool,
    /// Entities with a confidence below this value are discarded.
    pub min_confidence: f32,
    /// Attempt to resolve RDF resource mentions to IRIs.
    pub enable_linking: bool,
}
impl Default for EntityExtractionConfig {
    /// Turns every extractor and entity linking on, and keeps entities whose
    /// confidence is at least `0.6`.
    fn default() -> Self {
        let enabled = true;
        Self {
            extract_persons: enabled,
            extract_organizations: enabled,
            extract_locations: enabled,
            extract_datetime: enabled,
            extract_rdf_resources: enabled,
            min_confidence: 0.6,
            enable_linking: enabled,
        }
    }
}
/// Regex- and heuristic-based entity extractor.
pub struct EntityExtractor {
    /// Feature switches and confidence threshold.
    config: EntityExtractionConfig,
    /// Optional statistical NER model. `EntityExtractor::new` always leaves
    /// this as `None`.
    ner_model: Option<NamedEntityRecognizer>,
    /// Compiled regexes for each pattern-based entity type.
    patterns: HashMap<EntityType, Vec<regex::Regex>>,
}
impl EntityExtractor {
    /// Confidence assigned to every regex-pattern match.
    const PATTERN_CONFIDENCE: f32 = 0.9;
    /// Confidence assigned to capitalized-word heuristic matches.
    const HEURISTIC_CONFIDENCE: f32 = 0.6;

    /// Creates an extractor with the given configuration and the built-in
    /// regex patterns.
    ///
    /// No NER model is loaded; `ner_model` is always `None` after construction.
    ///
    /// # Errors
    ///
    /// Returns an error if one of the built-in regexes fails to compile.
    pub fn new(config: EntityExtractionConfig) -> Result<Self> {
        let patterns = Self::build_patterns()?;
        info!(
            "Initialized entity extractor with {} pattern types",
            patterns.len()
        );
        Ok(Self {
            config,
            ner_model: None,
            patterns,
        })
    }

    /// Compiles the regexes used for the pattern-based entity types.
    fn build_patterns() -> Result<HashMap<EntityType, Vec<regex::Regex>>> {
        let mut patterns = HashMap::new();
        // Stop at whitespace, angle brackets and quotes, and never end on
        // sentence punctuation, so "see https://example.org." yields the bare URL.
        patterns.insert(
            EntityType::URL,
            vec![regex::Regex::new(r#"https?://[^\s<>"]*[^\s<>"'.,;:!?]"#)?],
        );
        // Inside a character class `|` is a literal, so the TLD class is
        // letters only.
        patterns.insert(
            EntityType::Email,
            vec![regex::Regex::new(
                r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b",
            )?],
        );
        // One regex with an optional fraction, so "3.14" is a single entity
        // rather than "3", "14" and "3.14".
        patterns.insert(
            EntityType::Number,
            vec![regex::Regex::new(r"\b\d+(?:\.\d+)?\b")?],
        );
        patterns.insert(
            EntityType::DateTime,
            vec![
                regex::Regex::new(r"\b\d{4}-\d{2}-\d{2}\b")?,
                regex::Regex::new(r"\b\d{1,2}/\d{1,2}/\d{4}\b")?,
            ],
        );
        patterns.insert(
            EntityType::RDFResource,
            vec![
                // IRIs cannot contain whitespace, so "a < b and c > d" is not one.
                regex::Regex::new(r"<[^<>\s]+>")?,
                // The word boundary stops "Schema:Person" from matching as
                // "chema:Person".
                regex::Regex::new(r"\b[a-z]+:[A-Za-z0-9_-]+")?,
            ],
        );
        Ok(patterns)
    }

    /// Extracts entities from `text`.
    ///
    /// Runs the regexes of every enabled pattern type and, when person or
    /// organization extraction is enabled, the capitalized-word heuristic.
    /// Entities below `min_confidence` are dropped. The result is ordered by
    /// byte span (`start`, then `end`), and exact duplicates (same span and
    /// type) are removed.
    ///
    /// # Errors
    ///
    /// Currently never returns an error; the `Result` is kept for API stability.
    pub fn extract(&self, text: &str) -> Result<Vec<ExtractedEntity>> {
        debug!(
            "Extracting entities from: {}",
            text.chars().take(100).collect::<String>()
        );
        let mut entities = self.extract_pattern_entities(text);
        if self.config.extract_persons || self.config.extract_organizations {
            entities.extend(self.extract_capitalized_entities(text));
        }
        entities.retain(|e| e.confidence >= self.config.min_confidence);
        // `patterns` is a HashMap with arbitrary iteration order. A total sort
        // key (span, then variant declaration order) makes the output
        // deterministic.
        entities.sort_by_key(|e| (e.start, e.end, e.entity_type as u8));
        entities.dedup_by(|a, b| {
            a.start == b.start && a.end == b.end && a.entity_type == b.entity_type
        });
        debug!("Extracted {} entities", entities.len());
        Ok(entities)
    }

    /// Runs the regexes of every enabled pattern type over `text`.
    fn extract_pattern_entities(&self, text: &str) -> Vec<ExtractedEntity> {
        let mut entities = Vec::new();
        for (&entity_type, regexes) in &self.patterns {
            if !self.is_enabled(entity_type) {
                continue;
            }
            for found in regexes.iter().flat_map(|re| re.find_iter(text)) {
                let matched = found.as_str();
                entities.push(ExtractedEntity {
                    text: matched.to_string(),
                    entity_type,
                    start: found.start(),
                    end: found.end(),
                    confidence: Self::PATTERN_CONFIDENCE,
                    resolved_uri: self.resolve_uri(matched, entity_type),
                    metadata: HashMap::new(),
                });
            }
        }
        entities
    }

    /// Reports whether `entity_type` is switched on by the configuration.
    ///
    /// Types without a dedicated config flag are always enabled.
    fn is_enabled(&self, entity_type: EntityType) -> bool {
        match entity_type {
            EntityType::Person => self.config.extract_persons,
            EntityType::Organization => self.config.extract_organizations,
            EntityType::Location => self.config.extract_locations,
            EntityType::DateTime => self.config.extract_datetime,
            EntityType::RDFResource => self.config.extract_rdf_resources,
            EntityType::Number
            | EntityType::URL
            | EntityType::Email
            | EntityType::Property
            | EntityType::Class
            | EntityType::Other => true,
        }
    }

    /// Heuristic person/organization extraction from capitalized words.
    ///
    /// Each whitespace-separated token has surrounding ASCII punctuation
    /// trimmed. It qualifies when it:
    /// - starts with an uppercase character,
    /// - has more than two characters, and
    /// - is not the first token of `text` (a crude guard against a
    ///   sentence-initial capital).
    ///
    /// Words ending in `Inc`, `Corp` or `Ltd` become organizations; all others
    /// become persons. Types disabled in the config are skipped.
    fn extract_capitalized_entities(&self, text: &str) -> Vec<ExtractedEntity> {
        let is_punct = |c: char| c.is_ascii_punctuation();
        let mut entities = Vec::new();
        let mut cursor = 0;
        for raw in text.split_whitespace() {
            // Tokens are yielded in order, so the first occurrence at or after
            // `cursor` is this token's real byte position.
            let raw_start = text[cursor..].find(raw).map_or(cursor, |p| p + cursor);
            cursor = raw_start + raw.len();

            // Trim punctuation so "Alice," yields "Alice" and "Inc." still
            // matches the organization suffix check.
            let word = raw.trim_matches(is_punct);
            let starts_upper = matches!(word.chars().next(), Some(c) if c.is_uppercase());
            // Count chars, not bytes: a two-letter word of multi-byte
            // characters must not pass the length check.
            if raw_start == 0 || !starts_upper || word.chars().count() <= 2 {
                continue;
            }

            let entity_type = if ["Inc", "Corp", "Ltd"]
                .iter()
                .any(|&suffix| word.ends_with(suffix))
            {
                EntityType::Organization
            } else {
                EntityType::Person
            };
            if !self.is_enabled(entity_type) {
                continue;
            }

            // `word` is a subslice of `raw`, offset by the trimmed leading
            // punctuation.
            let start = raw_start + (raw.len() - raw.trim_start_matches(is_punct).len());
            entities.push(ExtractedEntity {
                text: word.to_string(),
                entity_type,
                start,
                end: start + word.len(),
                confidence: Self::HEURISTIC_CONFIDENCE,
                resolved_uri: None,
                metadata: HashMap::new(),
            });
        }
        entities
    }

    /// Links an RDF resource mention to an IRI when linking is enabled.
    ///
    /// - `<iri>` resolves to the non-empty text between the brackets.
    /// - Any other mention containing `:` (a prefixed name such as
    ///   `schema:Person`) is appended verbatim to the placeholder namespace
    ///   `http://example.org/`; prefixes are not expanded.
    /// - All non-RDF entity types resolve to `None`.
    fn resolve_uri(&self, text: &str, entity_type: EntityType) -> Option<String> {
        if !self.config.enable_linking || entity_type != EntityType::RDFResource {
            return None;
        }
        if let Some(iri) = text
            .strip_prefix('<')
            .and_then(|rest| rest.strip_suffix('>'))
        {
            return (!iri.is_empty()).then(|| iri.to_string());
        }
        text.contains(':')
            .then(|| format!("http://example.org/{}", text))
    }

    /// Extracts entities and groups them by [`EntityType`].
    ///
    /// Despite the name, no relations between entities are computed. Each
    /// group keeps the ordering produced by [`EntityExtractor::extract`].
    ///
    /// # Errors
    ///
    /// Propagates any error from [`EntityExtractor::extract`].
    pub fn extract_with_relations(
        &self,
        text: &str,
    ) -> Result<HashMap<EntityType, Vec<ExtractedEntity>>> {
        let mut grouped: HashMap<EntityType, Vec<ExtractedEntity>> = HashMap::new();
        for entity in self.extract(text)? {
            grouped.entry(entity.entity_type).or_default().push(entity);
        }
        Ok(grouped)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an extractor using the default configuration.
    fn default_extractor() -> EntityExtractor {
        EntityExtractor::new(EntityExtractionConfig::default()).expect("should succeed")
    }

    /// True when extracting `text` yields at least one entity of type `wanted`.
    fn finds(text: &str, wanted: EntityType) -> bool {
        default_extractor()
            .extract(text)
            .expect("should succeed")
            .iter()
            .any(|entity| entity.entity_type == wanted)
    }

    #[test]
    fn test_url_extraction() {
        assert!(finds(
            "Check out https://example.org for more info",
            EntityType::URL
        ));
    }

    #[test]
    fn test_email_extraction() {
        assert!(finds("Contact us at support@example.com", EntityType::Email));
    }

    #[test]
    fn test_number_extraction() {
        assert!(finds("There are 42 items in the database", EntityType::Number));
    }

    #[test]
    fn test_rdf_resource_extraction() {
        assert!(finds(
            "Query for schema:Person resources",
            EntityType::RDFResource
        ));
    }
}